From 7c4a1d92388e4dd0cf14c00ea98494ace5154f62 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Mon, 15 May 2023 16:02:56 +0200 Subject: [PATCH] use now model_display_name also for model's file name --- mlair/run_modules/experiment_setup.py | 4 ++-- mlair/run_modules/model_setup.py | 21 ++++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py index 8bbcfddf..e1b823fb 100644 --- a/mlair/run_modules/experiment_setup.py +++ b/mlair/run_modules/experiment_setup.py @@ -244,7 +244,7 @@ class ExperimentSetup(RunEnvironment): do_uncertainty_estimate: bool = None, do_bias_free_evaluation: bool = None, model_display_name: str = None, transformation_file: str = None, calculate_fresh_transformation: bool = None, snapshot_load_path: str = None, - create_snapshot: bool = None, snapshot_path: str = None, **kwargs): + create_snapshot: bool = None, snapshot_path: str = None, model_path: str = None, **kwargs): # create run framework super().__init__() @@ -299,7 +299,7 @@ class ExperimentSetup(RunEnvironment): self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data")) # set model path - self._set_param("model_path", None, os.path.join(experiment_path, "model")) + self._set_param("model_path", model_path, default=os.path.join(experiment_path, "model")) path_config.check_path_and_create(self.data_store.get("model_path")) # set plot path diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py index efeff062..819b6e0e 100644 --- a/mlair/run_modules/model_setup.py +++ b/mlair/run_modules/model_setup.py @@ -60,11 +60,12 @@ class ModelSetup(RunEnvironment): self.model = None exp_name = self.data_store.get("experiment_name") self.path = self.data_store.get("model_path") + self.model_display_name = self.data_store.get_default("model_display_name", default=None) self.scope = "model" - path = os.path.join(self.path, f"{exp_name}_%s") - self.model_name = path % "%s.h5" - self.checkpoint_name = path % "model-best.h5" - self.callbacks_name = path % "model-best-callbacks-%s.pickle" + path = os.path.join(self.path, f"{self.model_display_name or exp_name}%s") + self.model_path = path % "%s.h5" + self.checkpoint_name = path % "_model-best.h5" + self.callbacks_name = path % "_model-best-callbacks-%s.pickle" self._train_model = self.data_store.get("train_model") self._create_new_model = self.data_store.get("create_new_model") self._run() @@ -162,8 +163,8 @@ class ModelSetup(RunEnvironment): def load_model(self): """Try to load model from disk or skip if not possible.""" try: - self.model.load_model(self.model_name) - logging.info(f"reload model {self.model_name} from disk ...") + self.model.load_model(self.model_path) + logging.info(f"reload model {self.model_path} from disk ...") except OSError: logging.info('no local model to load...') @@ -195,14 +196,16 @@ class ModelSetup(RunEnvironment): """Load all model settings and store in data store.""" model_settings = self.model.get_settings() self.data_store.set_from_dict(model_settings, self.scope, log=True) - self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model") - self.data_store.set("model_name", self.model_name, self.scope) + generic_model_name = self.data_store.get_default("model_name", self.scope, "my_model") + model_annotation = generic_model_name if self.model_display_name is None else "" + self.model_path = self.model_path % model_annotation + self.data_store.set("model_name", self.model_path, self.scope) def plot_model(self): # pragma: no cover """Plot model architecture as `<model_name>.pdf`.""" try: with tf.device("/cpu:0"): - file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf" + file_name = f"{self.model_path.rsplit('.', 1)[0]}.pdf" keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) except Exception as e: logging.info(f"Can not plot model due to: {e}") -- GitLab