From 7c4a1d92388e4dd0cf14c00ea98494ace5154f62 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Mon, 15 May 2023 16:02:56 +0200
Subject: [PATCH] use now model_display_name also for model's file name

---
 mlair/run_modules/experiment_setup.py |  4 ++--
 mlair/run_modules/model_setup.py      | 21 ++++++++++++---------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index 8bbcfddf..e1b823fb 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -244,7 +244,7 @@ class ExperimentSetup(RunEnvironment):
                  do_uncertainty_estimate: bool = None, do_bias_free_evaluation: bool = None,
                  model_display_name: str = None, transformation_file: str = None,
                  calculate_fresh_transformation: bool = None, snapshot_load_path: str = None,
-                 create_snapshot: bool = None, snapshot_path: str = None, **kwargs):
+                 create_snapshot: bool = None, snapshot_path: str = None, model_path: str = None, **kwargs):
 
         # create run framework
         super().__init__()
@@ -299,7 +299,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("batch_path", batch_path, default=os.path.join(experiment_path, "batch_data"))
 
         # set model path
-        self._set_param("model_path", None, os.path.join(experiment_path, "model"))
+        self._set_param("model_path", model_path, default=os.path.join(experiment_path, "model"))
         path_config.check_path_and_create(self.data_store.get("model_path"))
 
         # set plot path
diff --git a/mlair/run_modules/model_setup.py b/mlair/run_modules/model_setup.py
index efeff062..819b6e0e 100644
--- a/mlair/run_modules/model_setup.py
+++ b/mlair/run_modules/model_setup.py
@@ -60,11 +60,12 @@ class ModelSetup(RunEnvironment):
         self.model = None
         exp_name = self.data_store.get("experiment_name")
         self.path = self.data_store.get("model_path")
+        self.model_display_name = self.data_store.get_default("model_display_name", default=None)
         self.scope = "model"
-        path = os.path.join(self.path, f"{exp_name}_%s")
-        self.model_name = path % "%s.h5"
-        self.checkpoint_name = path % "model-best.h5"
-        self.callbacks_name = path % "model-best-callbacks-%s.pickle"
+        path = os.path.join(self.path, f"{self.model_display_name or exp_name}%s")
+        self.model_path = path % "%s.h5"
+        self.checkpoint_name = path % "_model-best.h5"
+        self.callbacks_name = path % "_model-best-callbacks-%s.pickle"
         self._train_model = self.data_store.get("train_model")
         self._create_new_model = self.data_store.get("create_new_model")
         self._run()
@@ -162,8 +163,8 @@ class ModelSetup(RunEnvironment):
     def load_model(self):
         """Try to load model from disk or skip if not possible."""
         try:
-            self.model.load_model(self.model_name)
-            logging.info(f"reload model {self.model_name} from disk ...")
+            self.model.load_model(self.model_path)
+            logging.info(f"reload model {self.model_path} from disk ...")
         except OSError:
             logging.info('no local model to load...')
 
@@ -195,14 +196,16 @@ class ModelSetup(RunEnvironment):
         """Load all model settings and store in data store."""
         model_settings = self.model.get_settings()
         self.data_store.set_from_dict(model_settings, self.scope, log=True)
-        self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model")
-        self.data_store.set("model_name", self.model_name, self.scope)
+        generic_model_name = self.data_store.get_default("model_name", self.scope, "my_model")
+        model_annotation = generic_model_name if self.model_display_name is None else ""
+        self.model_path = self.model_path % model_annotation
+        self.data_store.set("model_name", self.model_path, self.scope)
 
     def plot_model(self):  # pragma: no cover
         """Plot model architecture as `<model_name>.pdf`."""
         try:
             with tf.device("/cpu:0"):
-                file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf"
+                file_name = f"{self.model_path.rsplit('.', 1)[0]}.pdf"
                 keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True)
         except Exception as e:
             logging.info(f"Can not plot model due to: {e}")
-- 
GitLab