diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index bb4801acfb5c3a643aecbcfad9cfdb758258d0ef..1a5dd9da38520d6d732253015e8a67325e24c460 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -27,6 +27,7 @@ class AbstractModelClass(ABC): self.__model = None self.__loss = None + self.model_name = self.__class__.__name__ def __getattr__(self, name: str) -> Any: diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 01a565e3db284e56bf0b8c94420b71268fd21a80..d3f6a56a2ec42e41454f0ced8853cb8d1001e495 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -527,7 +527,7 @@ class PlotTimeSeries(RunEnvironment): for i_year in range(end - start + 1): data_year = data.sel(index=f"{start + i_year}") for i_half_of_year in range(factor): - pos = 2 * i_year + i_half_of_year + pos = factor * i_year + i_half_of_year plot_data = self._create_plot_data(data_year, factor, i_half_of_year) self._plot_orig(axes[pos], plot_data) self._plot_ahead(axes[pos], plot_data) diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 7049af591cdf73434459d4c3f5f6c11e80ab64c0..2c73dad4bc57e529a417a2d4f4e476dfd7624c5a 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -28,8 +28,10 @@ class ModelSetup(RunEnvironment): path = self.data_store.get("experiment_path", "general") exp_name = self.data_store.get("experiment_name", "general") self.scope = "general.model" - self.checkpoint_name = os.path.join(path, f"{exp_name}_model-best.h5") - self.callbacks_name = os.path.join(path, f"{exp_name}_model-best-callbacks-%s.pickle") + self.path = os.path.join(path, f"{exp_name}_%s") + self.model_name = self.path % "%s.h5" + self.checkpoint_name = self.path % "model-best.h5" + self.callbacks_name = self.path % "model-best-callbacks-%s.pickle" self._run() def _run(self): @@ -79,8 +81,8 @@ class ModelSetup(RunEnvironment): def load_weights(self): try: - self.model.load_weights(self.checkpoint_name) - logging.info('reload weights...') + self.model.load_weights(self.model_name) + logging.info(f"reload weights from model {self.model_name} ...") except OSError: logging.info('no weights to reload...') @@ -93,12 +95,12 @@ class ModelSetup(RunEnvironment): def get_model_settings(self): model_settings = self.model.get_settings() self.data_store.set_args_from_dict(model_settings, self.scope) + self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model") + self.data_store.set("model_name", self.model_name, self.scope) def plot_model(self): # pragma: no cover with tf.device("/cpu:0"): - path = self.data_store.get("experiment_path", "general") - name = self.data_store.get("experiment_name", "general") + "_model.pdf" - file_name = os.path.join(path, name) + file_name = f"{self.model_name.split(sep='.')[0]}.pdf" keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 03d2e36e8662a573b96c970747e9fe4445244e9b..3c50799b939da2bc8517d1e58dd2c73baebe367b 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -55,10 +55,8 @@ class PostProcessing(RunEnvironment): try: model = self.data_store.get("best_model", "general") except NameNotFoundInDataStore: - logging.info("no model saved in data store. trying to load model from experiment") - path = self.data_store.get("experiment_path", "general") - name = f"{self.data_store.get('experiment_name', 'general')}_my_model.h5" - model_name = os.path.join(path, name) + logging.info("no model saved in data store. trying to load model from experiment path") + model_name = self.data_store.get("model_name", "general.model") model = keras.models.load_model(model_name) return model diff --git a/src/run_modules/training.py b/src/run_modules/training.py index e2a98f27c65e6050b0edae2bbc178abbf97ab646..ff2cffcdf01fd9e917bf1120984c6b65e1f5a13d 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -28,6 +28,7 @@ class Training(RunEnvironment): self.lr_sc = self.data_store.get("lr_decay", "general.model") self.hist = self.data_store.get("hist", "general.model") self.experiment_name = self.data_store.get("experiment_name", "general") + self._trainable = self.data_store.get("trainable", "general") self._run() def _run(self) -> None: @@ -44,8 +45,11 @@ class Training(RunEnvironment): """ self.set_generators() self.make_predict_function() - self.train() - self.save_model() + if self._trainable: + self.train() + self.save_model() + else: + logging.info("No training has started, because trainable parameter was false.") def make_predict_function(self) -> None: """ @@ -117,11 +121,9 @@ class Training(RunEnvironment): def save_model(self) -> None: """ - save model in local experiment directory. Model is named as <experiment_name>_my_model.h5 . + save model in local experiment directory. Model is named as <experiment_name>_<custom_model_name>.h5 . """ - path = self.data_store.get("experiment_path", "general") - name = f"{self.data_store.get('experiment_name', 'general')}_my_model.h5" - model_name = os.path.join(path, name) + model_name = self.data_store.get("model_name", "general.model") logging.debug(f"save best model to {model_name}") self.model.save(model_name) self.data_store.set("best_model", self.model, "general")