diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 7049af591cdf73434459d4c3f5f6c11e80ab64c0..2c73dad4bc57e529a417a2d4f4e476dfd7624c5a 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -28,8 +28,10 @@ class ModelSetup(RunEnvironment): path = self.data_store.get("experiment_path", "general") exp_name = self.data_store.get("experiment_name", "general") self.scope = "general.model" - self.checkpoint_name = os.path.join(path, f"{exp_name}_model-best.h5") - self.callbacks_name = os.path.join(path, f"{exp_name}_model-best-callbacks-%s.pickle") + self.path = os.path.join(path, f"{exp_name}_%s") + self.model_name = self.path % "%s.h5" + self.checkpoint_name = self.path % "model-best.h5" + self.callbacks_name = self.path % "model-best-callbacks-%s.pickle" self._run() def _run(self): @@ -79,8 +81,8 @@ class ModelSetup(RunEnvironment): def load_weights(self): try: - self.model.load_weights(self.checkpoint_name) - logging.info('reload weights...') + self.model.load_weights(self.model_name) + logging.info(f"reload weights from model {self.model_name} ...") except OSError: logging.info('no weights to reload...') @@ -93,12 +95,12 @@ class ModelSetup(RunEnvironment): def get_model_settings(self): model_settings = self.model.get_settings() self.data_store.set_args_from_dict(model_settings, self.scope) + self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model") + self.data_store.set("model_name", self.model_name, self.scope) def plot_model(self): # pragma: no cover with tf.device("/cpu:0"): - path = self.data_store.get("experiment_path", "general") - name = self.data_store.get("experiment_name", "general") + "_model.pdf" - file_name = os.path.join(path, name) + file_name = f"{self.model_name.split(sep='.')[0]}.pdf" keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 03d2e36e8662a573b96c970747e9fe4445244e9b..3c50799b939da2bc8517d1e58dd2c73baebe367b 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -55,10 +55,8 @@ class PostProcessing(RunEnvironment): try: model = self.data_store.get("best_model", "general") except NameNotFoundInDataStore: - logging.info("no model saved in data store. trying to load model from experiment") - path = self.data_store.get("experiment_path", "general") - name = f"{self.data_store.get('experiment_name', 'general')}_my_model.h5" - model_name = os.path.join(path, name) + logging.info("no model saved in data store. trying to load model from experiment path") + model_name = self.data_store.get("model_name", "general.model") model = keras.models.load_model(model_name) return model diff --git a/src/run_modules/training.py b/src/run_modules/training.py index e2a98f27c65e6050b0edae2bbc178abbf97ab646..7eb1cd7ac93ad7ea438a738bcf2ab5c1dd6397a2 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -117,11 +117,9 @@ class Training(RunEnvironment): def save_model(self) -> None: """ - save model in local experiment directory. Model is named as <experiment_name>_my_model.h5 . + save model in local experiment directory. Model is named as <experiment_name>_<custom_model_name>.h5 . """ - path = self.data_store.get("experiment_path", "general") - name = f"{self.data_store.get('experiment_name', 'general')}_my_model.h5" - model_name = os.path.join(path, name) + model_name = self.data_store.get("model_name", "general.model") logging.debug(f"save best model to {model_name}") self.model.save(model_name) self.data_store.set("best_model", self.model, "general")