Skip to content
Snippets Groups Projects
Commit 10500dad authored by lukas leufen
Browse files

also add min_length parameter for train and val set

parent ca36674a
No related branches found
No related tags found
3 merge requests: !90 WIP: new release update, !89 Resolve "release branch / CI on gpu", !68 Resolve "Minimal sample size in train-, val, test-set"
......@@ -35,7 +35,7 @@ class ExperimentSetup(RunEnvironment):
test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
test_min_length=None):
train_min_length=None, val_min_length=None, test_min_length=None):
# create run framework
super().__init__()
......@@ -100,19 +100,23 @@ class ExperimentSetup(RunEnvironment):
# train set parameters
self._set_param("start", train_start, default="1997-01-01", scope="general.train")
self._set_param("end", train_end, default="2007-12-31", scope="general.train")
self._set_param("min_length", train_min_length, default=90, scope="general.train")
# validation set parameters
self._set_param("start", val_start, default="2008-01-01", scope="general.val")
self._set_param("end", val_end, default="2009-12-31", scope="general.val")
self._set_param("min_length", val_min_length, default=90, scope="general.val")
# test set parameters
self._set_param("start", test_start, default="2010-01-01", scope="general.test")
self._set_param("end", test_end, default="2017-12-31", scope="general.test")
self._set_param("min_length", test_min_length, default=30, scope="general.test")
self._set_param("min_length", test_min_length, default=90, scope="general.test")
# train_val set parameters
self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val")
self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val")
train_val_min_length = sum([self.data_store.get("min_length", f"general.{s}") for s in ["train", "val"]])
self._set_param("min_length", train_val_min_length, default=180, scope="general.train_val")
# use all stations on all data sets (train, val, test)
self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True)
......
......@@ -85,12 +85,19 @@ class TestExperimentSetup:
# train parameters
assert data_store.get("start", "general.train") == "1997-01-01"
assert data_store.get("end", "general.train") == "2007-12-31"
assert data_store.get("min_length", "general.train") == 90
# validation parameters
assert data_store.get("start", "general.val") == "2008-01-01"
assert data_store.get("end", "general.val") == "2009-12-31"
assert data_store.get("min_length", "general.val") == 90
# test parameters
assert data_store.get("start", "general.test") == "2010-01-01"
assert data_store.get("end", "general.test") == "2017-12-31"
assert data_store.get("min_length", "general.test") == 90
# train_val parameters
assert data_store.get("start", "general.train_val") == "1997-01-01"
assert data_store.get("end", "general.train_val") == "2009-12-31"
assert data_store.get("min_length", "general.train_val") == 180
# use all stations on all data sets (train, val, test)
assert data_store.get("use_all_stations_on_all_data_sets", "general") is True
......@@ -104,7 +111,7 @@ class TestExperimentSetup:
interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05",
test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False,
fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True)
fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True, val_min_length=20)
exp_setup = ExperimentSetup(**kwargs)
data_store = exp_setup.data_store
# experiment setup
......@@ -139,12 +146,19 @@ class TestExperimentSetup:
# train parameters
assert data_store.get("start", "general.train") == "2000-01-01"
assert data_store.get("end", "general.train") == "2000-01-02"
assert data_store.get("min_length", "general.train") == 90
# validation parameters
assert data_store.get("start", "general.val") == "2000-01-03"
assert data_store.get("end", "general.val") == "2000-01-04"
assert data_store.get("min_length", "general.val") == 20
# test parameters
assert data_store.get("start", "general.test") == "2000-01-05"
assert data_store.get("end", "general.test") == "2000-01-06"
assert data_store.get("min_length", "general.test") == 90
# train_val parameters
assert data_store.get("start", "general.train_val") == "2000-01-01"
assert data_store.get("end", "general.train_val") == "2000-01-04"
assert data_store.get("min_length", "general.train_val") == 110
# use all stations on all data sets (train, val, test)
assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment