diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index f34fb74dc6961c33d7b3c144d94f46d020c603fd..139969c76fe21498a2ac738426e33d1680b60a3f 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -273,6 +273,10 @@ class ExperimentSetup(RunEnvironment): self._set_param("experiment_path", exp_path) path_config.check_path_and_create(self.data_store.get("experiment_path")) + # set model path + self._set_param("model_path", None, os.path.join(exp_path, "model")) + path_config.check_path_and_create(self.data_store.get("model_path")) + # set plot path default_plot_path = os.path.join(exp_path, "plots") self._set_param("plot_path", plot_path, default=default_plot_path) @@ -283,6 +287,10 @@ class ExperimentSetup(RunEnvironment): self._set_param("forecast_path", forecast_path, default_forecast_path) path_config.check_path_and_create(self.data_store.get("forecast_path")) + # set logging path + self._set_param("logging_path", None, os.path.join(exp_path, "logging")) + path_config.check_path_and_create(self.data_store.get("logging_path")) + # setup for data self._set_param("stations", stations, default=DEFAULT_STATIONS) self._set_param("network", network, default="AIRBASE") diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index 804c4b9403e3b61ca61522cfa9a56588b4776be5..618d7cd8bfff3a253b5f3084512e9ba72c603c8c 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -58,8 +58,9 @@ class ModelSetup(RunEnvironment): """Initialise and run model setup.""" super().__init__() self.model = None - path = self.data_store.get("experiment_path") + # path = self.data_store.get("experiment_path") exp_name = self.data_store.get("experiment_name") + path = self.data_store.get("model_path") self.scope = "model" self.path = os.path.join(path, f"{exp_name}_%s") self.model_name = self.path % "%s.h5" diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 143e2908b6398412211494b02bf84c3c741c18b2..dedcda0a6a3ff8fb9246bc6efe097eeb6b463999 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -39,7 +39,7 @@ class PostProcessing(RunEnvironment): * `generator` [train, val, test, train_val] * `forecast_path` [.] * `plot_path` [postprocessing] - * `experiment_path` [.] + * `model_path` [.] * `target_var` [.] * `sampling` [.] * `window_lead_time` [.] @@ -292,7 +292,7 @@ class PostProcessing(RunEnvironment): """Evaluate test score of model and save locally.""" test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(), use_multiprocessing=False, verbose=0, steps=1) - path = self.data_store.get("experiment_path") + path = self.data_store.get("model_path") with open(os.path.join(path, "test_scores.txt"), "a") as f: for index, item in enumerate(test_score): logging.info(f"{self.model.metrics_names[index]}, {item}") diff --git a/src/run_modules/run_environment.py b/src/run_modules/run_environment.py index a0e619f364a060b3ed44639c6057046db197d84b..45d0a4a019b305d477838bd9ec4c5b6f920ac6fb 100644 --- a/src/run_modules/run_environment.py +++ b/src/run_modules/run_environment.py @@ -153,7 +153,7 @@ class RunEnvironment(object): def __find_file_pattern(self, name): counter = 0 - filename_pattern = os.path.join(self.data_store.get_default("experiment_path", os.path.realpath(".")), name) + filename_pattern = os.path.join(self.data_store.get_default("logging_path", os.path.realpath(".")), name) new_file = filename_pattern % counter while os.path.exists(new_file): counter += 1 diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 8cb4726fdc84ad10e62106c1d2bcbf899457e31d..8624b51512447924a1052ed47bc0d62f709781d1 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -194,7 +194,7 @@ class Training(RunEnvironment): :param lr_sc: learning rate object """ logging.debug("saving callbacks") - path = self.data_store.get("experiment_path") + path = self.data_store.get("model_path") with open(os.path.join(path, "history.json"), "w") as f: json.dump(history.history, f) if lr_sc: