From 02e43daade99083c4c0070ebb0875621dac00379 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Fri, 19 Jun 2020 10:35:52 +0200
Subject: [PATCH] adjust logging path and model path inside experiment

---
 src/run_modules/experiment_setup.py | 8 ++++++++
 src/run_modules/model_setup.py      | 3 ++-
 src/run_modules/post_processing.py  | 4 ++--
 src/run_modules/run_environment.py  | 2 +-
 src/run_modules/training.py         | 2 +-
 5 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index f34fb74d..139969c7 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -273,6 +273,10 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("experiment_path", exp_path)
         path_config.check_path_and_create(self.data_store.get("experiment_path"))
 
+        # set model path
+        self._set_param("model_path", None, os.path.join(exp_path, "model"))
+        path_config.check_path_and_create(self.data_store.get("model_path"))
+
         # set plot path
         default_plot_path = os.path.join(exp_path, "plots")
         self._set_param("plot_path", plot_path, default=default_plot_path)
@@ -283,6 +287,10 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("forecast_path", forecast_path, default_forecast_path)
         path_config.check_path_and_create(self.data_store.get("forecast_path"))
 
+        # set logging path
+        self._set_param("logging_path", None, os.path.join(exp_path, "logging"))
+        path_config.check_path_and_create(self.data_store.get("logging_path"))
+
         # setup for data
         self._set_param("stations", stations, default=DEFAULT_STATIONS)
         self._set_param("network", network, default="AIRBASE")
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 804c4b94..618d7cd8 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -58,8 +58,9 @@ class ModelSetup(RunEnvironment):
         """Initialise and run model setup."""
         super().__init__()
         self.model = None
-        path = self.data_store.get("experiment_path")
+        # path = self.data_store.get("experiment_path")
         exp_name = self.data_store.get("experiment_name")
+        path = self.data_store.get("model_path")
         self.scope = "model"
         self.path = os.path.join(path, f"{exp_name}_%s")
         self.model_name = self.path % "%s.h5"
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 143e2908..dedcda0a 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -39,7 +39,7 @@ class PostProcessing(RunEnvironment):
         * `generator` [train, val, test, train_val]
         * `forecast_path` [.]
         * `plot_path` [postprocessing]
-        * `experiment_path` [.]
+        * `model_path` [.]
         * `target_var` [.]
         * `sampling` [.]
         * `window_lead_time` [.]
@@ -292,7 +292,7 @@ class PostProcessing(RunEnvironment):
         """Evaluate test score of model and save locally."""
         test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
                                                    use_multiprocessing=False, verbose=0, steps=1)
-        path = self.data_store.get("experiment_path")
+        path = self.data_store.get("model_path")
         with open(os.path.join(path, "test_scores.txt"), "a") as f:
             for index, item in enumerate(test_score):
                 logging.info(f"{self.model.metrics_names[index]}, {item}")
diff --git a/src/run_modules/run_environment.py b/src/run_modules/run_environment.py
index a0e619f3..45d0a4a0 100644
--- a/src/run_modules/run_environment.py
+++ b/src/run_modules/run_environment.py
@@ -153,7 +153,7 @@ class RunEnvironment(object):
 
     def __find_file_pattern(self, name):
         counter = 0
-        filename_pattern = os.path.join(self.data_store.get_default("experiment_path", os.path.realpath(".")), name)
+        filename_pattern = os.path.join(self.data_store.get_default("logging_path", os.path.realpath(".")), name)
         new_file = filename_pattern % counter
         while os.path.exists(new_file):
             counter += 1
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index 8cb4726f..8624b515 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -194,7 +194,7 @@ class Training(RunEnvironment):
         :param lr_sc: learning rate object
         """
         logging.debug("saving callbacks")
-        path = self.data_store.get("experiment_path")
+        path = self.data_store.get("model_path")
         with open(os.path.join(path, "history.json"), "w") as f:
             json.dump(history.history, f)
         if lr_sc:
-- 
GitLab