Commit f65a442e authored by lukas leufen

use create_new_model to choose if an existing model shall be overwritten or not,

parent 0d879b56
2 merge requests: !50 release for v0.7.0, !39 use create_new_model to choose if an existing model shall be overwritten or not
Pipeline #29408 passed
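For context, a hedged usage sketch of the new option. The calls below only use keyword arguments that appear in the diff below; any further arguments the constructor may require are omitted here.

# keep and only evaluate an already trained model: do not rebuild, do not retrain
ExperimentSetup(trainable=False, create_new_model=False, sampling="daily")

# default behaviour (create_new_model resolves to True): build a fresh model and train it
ExperimentSetup(sampling="daily")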
@@ -31,15 +31,19 @@ class ExperimentSetup(RunEnvironment):
                  statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", target_dim=None,
                  window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
-                 test_end=None, use_all_stations_on_all_data_sets=True, trainable=False, fraction_of_train=None,
-                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily"):
+                 test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
+                 experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
+                 create_new_model=None):
 
         # create run framework
         super().__init__()
 
         # experiment setup
         self._set_param("data_path", helpers.prepare_host(sampling=sampling))
-        self._set_param("trainable", trainable, default=False)
+        self._set_param("create_new_model", create_new_model, default=True)
+        if self.data_store.get("create_new_model", "general"):
+            trainable = True
+        self._set_param("trainable", trainable, default=True)
         self._set_param("fraction_of_training", fraction_of_train, default=0.8)
 
         # set experiment name
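The net effect of the new flag handling can be summarised in a small standalone sketch. This is not project code (the data store and _set_param are replaced by plain Python), but the decision logic mirrors the lines above.

def resolve_flags(trainable=None, create_new_model=None):
    # create_new_model defaults to True, i.e. an existing model is overwritten
    create_new_model = True if create_new_model is None else create_new_model
    # asking for a new model always implies training it
    if create_new_model:
        trainable = True
    # trainable itself now defaults to True as well (previously False)
    trainable = True if trainable is None else trainable
    return trainable, create_new_model

assert resolve_flags() == (True, True)                 # defaults: new model, trained
assert resolve_flags(trainable=False) == (True, True)  # overridden: a new model must be trained
assert resolve_flags(False, False) == (False, False)   # keep an existing model and skip training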
@@ -85,19 +89,19 @@
         self._set_param("interpolate_method", interpolate_method, default='linear')
         self._set_param("limit_nan_fill", limit_nan_fill, default=1)
 
-        # train parameters
+        # train set parameters
         self._set_param("start", train_start, default="1997-01-01", scope="general.train")
         self._set_param("end", train_end, default="2007-12-31", scope="general.train")
 
-        # validation parameters
+        # validation set parameters
         self._set_param("start", val_start, default="2008-01-01", scope="general.val")
         self._set_param("end", val_end, default="2009-12-31", scope="general.val")
 
-        # test parameters
+        # test set parameters
         self._set_param("start", test_start, default="2010-01-01", scope="general.test")
         self._set_param("end", test_end, default="2017-12-31", scope="general.test")
 
-        # train_val parameters
+        # train_val set parameters
         self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val")
         self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val")
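For reference, the defaults above amount to the following split configuration, written out here as a plain mapping rather than as it is stored in the data store.

default_splits = {
    "general.train":     ("1997-01-01", "2007-12-31"),
    "general.val":       ("2008-01-01", "2009-12-31"),
    "general.test":      ("2010-01-01", "2017-12-31"),
    # train_val reuses the train start and the val end
    "general.train_val": ("1997-01-01", "2009-12-31"),
}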
@@ -32,6 +32,8 @@ class ModelSetup(RunEnvironment):
         self.model_name = self.path % "%s.h5"
         self.checkpoint_name = self.path % "model-best.h5"
         self.callbacks_name = self.path % "model-best-callbacks-%s.pickle"
+        self._trainable = self.data_store.get("trainable", "general")
+        self._create_new_model = self.data_store.get("create_new_model", "general")
         self._run()
 
     def _run(self):
@@ -46,7 +48,7 @@ class ModelSetup(RunEnvironment):
         self.plot_model()
 
         # load weights if no training shall be performed
-        if self.data_store.get("trainable", self.scope) is False:
+        if not self._trainable and not self._create_new_model:
             self.load_weights()
 
         # create checkpoint
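The changed condition means stored weights are only restored when the run is neither trainable nor asked to create a new model. A minimal truth-table sketch with a hypothetical helper name, not part of ModelSetup:

def should_load_weights(trainable: bool, create_new_model: bool) -> bool:
    # only a non-trainable run that keeps the existing model loads stored weights
    return not trainable and not create_new_model

assert should_load_weights(False, False) is True   # evaluate an existing model
assert should_load_weights(True, False) is False   # resume training, handled via the checkpoint in Training
assert should_load_weights(True, True) is False    # fresh model, nothing to load
# trainable=False with create_new_model=True cannot occur: ExperimentSetup forces trainable=True in that case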
@@ -29,6 +29,7 @@ class Training(RunEnvironment):
         self.hist = self.data_store.get("hist", "general.model")
         self.experiment_name = self.data_store.get("experiment_name", "general")
         self._trainable = self.data_store.get("trainable", "general")
+        self._create_new_model = self.data_store.get("create_new_model", "general")
         self._run()
 
     def _run(self) -> None:
@@ -86,7 +87,7 @@ class Training(RunEnvironment):
         locally stored information and the corresponding model and proceed with the already started training.
         """
         logging.info(f"Train with {len(self.train_set)} mini batches.")
-        if not os.path.exists(self.checkpoint.filepath):
+        if not os.path.exists(self.checkpoint.filepath) or self._create_new_model:
             history = self.model.fit_generator(generator=self.train_set.distribute_on_batches(),
                                                steps_per_epoch=len(self.train_set),
                                                epochs=self.epochs,
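The extended condition in Training follows the same flag: a fresh fit is started when there is no checkpoint to resume from or when a new model was explicitly requested; otherwise the branch not shown in this hunk reloads the stored model and callbacks and continues the previous training, as the docstring above describes. A hypothetical helper spelling out the condition:

import os

def fresh_training_required(checkpoint_path: str, create_new_model: bool) -> bool:
    # resume only if a checkpoint exists and the existing model shall not be overwritten
    return not os.path.exists(checkpoint_path) or create_new_model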