diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 387feaa8ad02e9405513a683a13266ba0366c682..2aafe6c693b8d10fe450df510539e0aef8f0487e 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -35,7 +35,7 @@ class ExperimentSetup(RunEnvironment): test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None, - test_min_length=None): + train_min_length=None, val_min_length=None, test_min_length=None): # create run framework super().__init__() @@ -100,19 +100,23 @@ class ExperimentSetup(RunEnvironment): # train set parameters self._set_param("start", train_start, default="1997-01-01", scope="general.train") self._set_param("end", train_end, default="2007-12-31", scope="general.train") + self._set_param("min_length", train_min_length, default=90, scope="general.train") # validation set parameters self._set_param("start", val_start, default="2008-01-01", scope="general.val") self._set_param("end", val_end, default="2009-12-31", scope="general.val") + self._set_param("min_length", val_min_length, default=90, scope="general.val") # test set parameters self._set_param("start", test_start, default="2010-01-01", scope="general.test") self._set_param("end", test_end, default="2017-12-31", scope="general.test") - self._set_param("min_length", test_min_length, default=30, scope="general.test") + self._set_param("min_length", test_min_length, default=90, scope="general.test") # train_val set parameters self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val") self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val") + train_val_min_length = sum([self.data_store.get("min_length", f"general.{s}") for s in ["train", "val"]]) + self._set_param("min_length", train_val_min_length, default=180, scope="general.train_val") # use all stations on all data sets (train, val, test) self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True) diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py index 894e4b552af4231ccc12fb85aaaebf5bbc23edf3..a3a83acf84e286d1f5da9b5caffa256fc0ca3327 100644 --- a/test/test_modules/test_experiment_setup.py +++ b/test/test_modules/test_experiment_setup.py @@ -85,12 +85,19 @@ class TestExperimentSetup: # train parameters assert data_store.get("start", "general.train") == "1997-01-01" assert data_store.get("end", "general.train") == "2007-12-31" + assert data_store.get("min_length", "general.train") == 90 # validation parameters assert data_store.get("start", "general.val") == "2008-01-01" assert data_store.get("end", "general.val") == "2009-12-31" + assert data_store.get("min_length", "general.val") == 90 # test parameters assert data_store.get("start", "general.test") == "2010-01-01" assert data_store.get("end", "general.test") == "2017-12-31" + assert data_store.get("min_length", "general.test") == 90 + # train_val parameters + assert data_store.get("start", "general.train_val") == "1997-01-01" + assert data_store.get("end", "general.train_val") == "2009-12-31" + assert data_store.get("min_length", "general.train_val") == 180 # use all stations on all data sets (train, val, test) assert data_store.get("use_all_stations_on_all_data_sets", "general") is True @@ -104,7 +111,7 @@ class TestExperimentSetup: interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01", train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05", test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False, - fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True) + fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True, val_min_length=20) exp_setup = ExperimentSetup(**kwargs) data_store = exp_setup.data_store # experiment setup @@ -139,12 +146,19 @@ class TestExperimentSetup: # train parameters assert data_store.get("start", "general.train") == "2000-01-01" assert data_store.get("end", "general.train") == "2000-01-02" + assert data_store.get("min_length", "general.train") == 90 # validation parameters assert data_store.get("start", "general.val") == "2000-01-03" assert data_store.get("end", "general.val") == "2000-01-04" + assert data_store.get("min_length", "general.val") == 20 # test parameters assert data_store.get("start", "general.test") == "2000-01-05" assert data_store.get("end", "general.test") == "2000-01-06" + assert data_store.get("min_length", "general.test") == 90 + # train_val parameters + assert data_store.get("start", "general.train_val") == "2000-01-01" + assert data_store.get("end", "general.train_val") == "2000-01-04" + assert data_store.get("min_length", "general.train_val") == 110 # use all stations on all data sets (train, val, test) assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False