From f65a442e73ba6595e7ed30b7b20a8bdca53e8c9a Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 12 Feb 2020 13:58:36 +0100
Subject: [PATCH] use create_new_model to choose whether an existing model
 shall be overwritten or not

---
 src/run_modules/experiment_setup.py | 18 +++++++++++-------
 src/run_modules/model_setup.py      |  4 +++-
 src/run_modules/training.py         |  3 ++-
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 44f1a382..48f7c13e 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -31,15 +31,19 @@ class ExperimentSetup(RunEnvironment):
                  statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3",
                  target_dim=None, window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
-                 test_end=None, use_all_stations_on_all_data_sets=True, trainable=False, fraction_of_train=None,
-                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily"):
+                 test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
+                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
+                 create_new_model=None):
 
         # create run framework
         super().__init__()
 
         # experiment setup
         self._set_param("data_path", helpers.prepare_host(sampling=sampling))
-        self._set_param("trainable", trainable, default=False)
+        self._set_param("create_new_model", create_new_model, default=True)
+        if self.data_store.get("create_new_model", "general"):
+            trainable = True
+        self._set_param("trainable", trainable, default=True)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
 
         # set experiment name
@@ -85,19 +89,19 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("interpolate_method", interpolate_method, default='linear')
         self._set_param("limit_nan_fill", limit_nan_fill, default=1)
 
-        # train parameters
+        # train set parameters
         self._set_param("start", train_start, default="1997-01-01", scope="general.train")
         self._set_param("end", train_end, default="2007-12-31", scope="general.train")
 
-        # validation parameters
+        # validation set parameters
         self._set_param("start", val_start, default="2008-01-01", scope="general.val")
         self._set_param("end", val_end, default="2009-12-31", scope="general.val")
 
-        # test parameters
+        # test set parameters
         self._set_param("start", test_start, default="2010-01-01", scope="general.test")
         self._set_param("end", test_end, default="2017-12-31", scope="general.test")
 
-        # train_val parameters
+        # train_val set parameters
         self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val")
         self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val")
 
diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py
index 2c73dad4..c14298d7 100644
--- a/src/run_modules/model_setup.py
+++ b/src/run_modules/model_setup.py
@@ -32,6 +32,8 @@ class ModelSetup(RunEnvironment):
         self.model_name = self.path % "%s.h5"
         self.checkpoint_name = self.path % "model-best.h5"
         self.callbacks_name = self.path % "model-best-callbacks-%s.pickle"
+        self._trainable = self.data_store.get("trainable", "general")
+        self._create_new_model = self.data_store.get("create_new_model", "general")
         self._run()
 
     def _run(self):
@@ -46,7 +48,7 @@ class ModelSetup(RunEnvironment):
         self.plot_model()
 
         # load weights if no training shall be performed
-        if self.data_store.get("trainable", self.scope) is False:
+        if not self._trainable and not self._create_new_model:
             self.load_weights()
 
         # create checkpoint
diff --git a/src/run_modules/training.py b/src/run_modules/training.py
index ff2cffcd..0b11da8d 100644
--- a/src/run_modules/training.py
+++ b/src/run_modules/training.py
@@ -29,6 +29,7 @@ class Training(RunEnvironment):
         self.hist = self.data_store.get("hist", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._trainable = self.data_store.get("trainable", "general")
+        self._create_new_model = self.data_store.get("create_new_model", "general")
         self._run()
 
     def _run(self) -> None:
@@ -86,7 +87,7 @@ class Training(RunEnvironment):
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        if not os.path.exists(self.checkpoint.filepath):
+        if not os.path.exists(self.checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                                steps_per_epoch=len(self.train_set),
                                                epochs=self.epochs,
--
GitLab
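
For reference, a minimal standalone sketch of the decision logic this patch introduces. The helper functions below are illustrative assumptions, not part of the patch or of the project's API; in the run modules the trainable and create_new_model flags are read from the data_store rather than passed as arguments.

# sketch.py -- illustrative only, assuming the flag semantics shown in the diff above
import os


def resolve_flags(trainable=None, create_new_model=None):
    # Mirrors ExperimentSetup: both flags default to True, and requesting a
    # new model forces training regardless of the trainable argument.
    create_new_model = True if create_new_model is None else create_new_model
    trainable = True if trainable is None else trainable
    if create_new_model:
        trainable = True
    return trainable, create_new_model


def should_load_weights(trainable, create_new_model):
    # Mirrors ModelSetup: stored weights are only loaded when no training is
    # requested and no new model shall be created.
    return not trainable and not create_new_model


def should_train_from_scratch(checkpoint_path, create_new_model):
    # Mirrors Training: an existing checkpoint is ignored (training restarts
    # from scratch) whenever a new model is requested.
    return not os.path.exists(checkpoint_path) or create_new_model


if __name__ == "__main__":
    trainable, create_new_model = resolve_flags(trainable=False, create_new_model=False)
    print(should_load_weights(trainable, create_new_model))               # True: reuse the stored model
    print(should_train_from_scratch("model-best.h5", create_new_model))   # False if the checkpoint exists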