Skip to content
Snippets Groups Projects
Commit 02e43daa authored by lukas leufen's avatar lukas leufen
Browse files

adjust logging path and model path inside experiment

parent 567de639
No related branches found
No related tags found
3 merge requests!125Release v0.10.0,!124Update Master to new version v0.10.0,!101Resolve "model folder in experiment"
Pipeline #39296 failed
......@@ -273,6 +273,10 @@ class ExperimentSetup(RunEnvironment):
self._set_param("experiment_path", exp_path)
path_config.check_path_and_create(self.data_store.get("experiment_path"))
# set model path
self._set_param("model_path", None, os.path.join(exp_path, "model"))
path_config.check_path_and_create(self.data_store.get("model_path"))
# set plot path
default_plot_path = os.path.join(exp_path, "plots")
self._set_param("plot_path", plot_path, default=default_plot_path)
......@@ -283,6 +287,10 @@ class ExperimentSetup(RunEnvironment):
self._set_param("forecast_path", forecast_path, default_forecast_path)
path_config.check_path_and_create(self.data_store.get("forecast_path"))
# set logging path
self._set_param("logging_path", None, os.path.join(exp_path, "logging"))
path_config.check_path_and_create(self.data_store.get("logging_path"))
# setup for data
self._set_param("stations", stations, default=DEFAULT_STATIONS)
self._set_param("network", network, default="AIRBASE")
......
......@@ -58,8 +58,9 @@ class ModelSetup(RunEnvironment):
"""Initialise and run model setup."""
super().__init__()
self.model = None
path = self.data_store.get("experiment_path")
# path = self.data_store.get("experiment_path")
exp_name = self.data_store.get("experiment_name")
path = self.data_store.get("model_path")
self.scope = "model"
self.path = os.path.join(path, f"{exp_name}_%s")
self.model_name = self.path % "%s.h5"
......
......@@ -39,7 +39,7 @@ class PostProcessing(RunEnvironment):
* `generator` [train, val, test, train_val]
* `forecast_path` [.]
* `plot_path` [postprocessing]
* `experiment_path` [.]
* `model_path` [.]
* `target_var` [.]
* `sampling` [.]
* `window_lead_time` [.]
......@@ -292,7 +292,7 @@ class PostProcessing(RunEnvironment):
"""Evaluate test score of model and save locally."""
test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
use_multiprocessing=False, verbose=0, steps=1)
path = self.data_store.get("experiment_path")
path = self.data_store.get("model_path")
with open(os.path.join(path, "test_scores.txt"), "a") as f:
for index, item in enumerate(test_score):
logging.info(f"{self.model.metrics_names[index]}, {item}")
......
......@@ -153,7 +153,7 @@ class RunEnvironment(object):
def __find_file_pattern(self, name):
counter = 0
filename_pattern = os.path.join(self.data_store.get_default("experiment_path", os.path.realpath(".")), name)
filename_pattern = os.path.join(self.data_store.get_default("logging_path", os.path.realpath(".")), name)
new_file = filename_pattern % counter
while os.path.exists(new_file):
counter += 1
......
......@@ -194,7 +194,7 @@ class Training(RunEnvironment):
:param lr_sc: learning rate object
"""
logging.debug("saving callbacks")
path = self.data_store.get("experiment_path")
path = self.data_store.get("model_path")
with open(os.path.join(path, "history.json"), "w") as f:
json.dump(history.history, f)
if lr_sc:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment