diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 44f1a3821dfacba146a0aabe8fb0254068d9e6d3..48f7c13e51622d7d52405b73c0a6f57537b5b476 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -31,15 +31,19 @@ class ExperimentSetup(RunEnvironment):
                  statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3",
                  target_dim=None, window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
-                 test_end=None, use_all_stations_on_all_data_sets=True, trainable=False, fraction_of_train=None,
-                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily"):
+                 test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
+                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
+                 create_new_model=None):
 
         # create run framework
         super().__init__()
 
         # experiment setup
         self._set_param("data_path", helpers.prepare_host(sampling=sampling))
-        self._set_param("trainable", trainable, default=False)
+        self._set_param("create_new_model", create_new_model, default=True)
+        if self.data_store.get("create_new_model", "general"):
+            trainable = True
+        self._set_param("trainable", trainable, default=True)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
 
         # set experiment name
@@ -85,19 +89,19 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("interpolate_method", interpolate_method, default='linear')
         self._set_param("limit_nan_fill", limit_nan_fill, default=1)
 
-        # train parameters
+        # train set parameters
         self._set_param("start", train_start, default="1997-01-01", scope="general.train")
         self._set_param("end", train_end, default="2007-12-31", scope="general.train")
 
-        # validation parameters
+        # validation set parameters
         self._set_param("start", val_start, default="2008-01-01", scope="general.val")
         self._set_param("end", val_end, default="2009-12-31", scope="general.val")
 
-        # test parameters
+        # test set parameters
         self._set_param("start", test_start, default="2010-01-01", scope="general.test")
         self._set_param("end", test_end, default="2017-12-31", scope="general.test")
 
-        # train_val parameters
+        # train_val set parameters
         self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val")
         self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val")
 
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 2c73dad4bc57e529a417a2d4f4e476dfd7624c5a..c14298d7d21f63cdc4465c1ed8e8bb30868b3c1a 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -32,6 +32,8 @@ class ModelSetup(RunEnvironment):
         self.model_name = self.path % "%s.h5"
         self.checkpoint_name = self.path % "model-best.h5"
         self.callbacks_name = self.path % "model-best-callbacks-%s.pickle"
+        self._trainable = self.data_store.get("trainable", "general")
+        self._create_new_model = self.data_store.get("create_new_model", "general")
         self._run()
 
     def _run(self):
@@ -46,7 +48,7 @@
         self.plot_model()
 
         # load weights if no training shall be performed
-        if self.data_store.get("trainable", self.scope) is False:
+        if not self._trainable and not self._create_new_model:
             self.load_weights()
 
         # create checkpoint
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index ff2cffcdf01fd9e917bf1120984c6b65e1f5a13d..0b11da8d8f9a23c51d787f00d00a74d7517ea3b1 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -29,6 +29,7 @@ class Training(RunEnvironment):
         self.hist = self.data_store.get("hist", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._trainable = self.data_store.get("trainable", "general")
+        self._create_new_model = self.data_store.get("create_new_model", "general")
         self._run()
 
     def _run(self) -> None:
@@ -86,7 +87,7 @@
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        if not os.path.exists(self.checkpoint.filepath):
+        if not os.path.exists(self.checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                                steps_per_epoch=len(self.train_set),
                                                epochs=self.epochs,
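Taken together, the three files tie the new create_new_model flag to the existing trainable flag: requesting a new model forces training, pre-trained weights are only loaded when neither flag is set, and an existing checkpoint is only resumed when no new model is requested. The sketch below only summarises that combined behaviour under those assumptions; the helper name resolve_model_handling and the checkpoint_exists argument are illustrative and not part of the repository.

# Illustrative summary of the flag handling introduced above (not repository code).
# "checkpoint_exists" stands in for os.path.exists(self.checkpoint.filepath) in Training.

def resolve_model_handling(trainable=None, create_new_model=None, checkpoint_exists=False):
    # ExperimentSetup: both parameters default to True, and asking for a new model
    # implies that the model is also trainable.
    create_new_model = True if create_new_model is None else create_new_model
    if create_new_model:
        trainable = True
    trainable = True if trainable is None else trainable

    # ModelSetup: load stored weights only if neither training nor a new model is wanted.
    load_weights = not trainable and not create_new_model

    # Training: resume from the checkpoint only if it exists and no new model is wanted;
    # otherwise fit_generator starts from scratch.
    resume_from_checkpoint = checkpoint_exists and not create_new_model

    return {"trainable": trainable, "load_weights": load_weights,
            "resume_from_checkpoint": resume_from_checkpoint}

# With all defaults a fresh model is built and trained, even if a checkpoint is present:
# resolve_model_handling(checkpoint_exists=True)
# -> {'trainable': True, 'load_weights': False, 'resume_from_checkpoint': False}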