Skip to content
Snippets Groups Projects
Commit 74b6c989 authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue076_feat_minimum-set-ample-size' into 'develop'

Resolve "Minimal sample size in train-, val, test-set"

See merge request toar/machinelearningtools!68
parents b146b12b 54efbc20
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!68Resolve "Minimal sample size in train-, val, test-set"
Pipeline #31828 passed
...@@ -353,7 +353,8 @@ class DataPrep(object): ...@@ -353,7 +353,8 @@ class DataPrep(object):
non_nan_observation = self.observation.dropna(dim=dim) non_nan_observation = self.observation.dropna(dim=dim)
intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values)) intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values))
if len(intersect) == 0: min_length = self.kwargs.get("min_length", 0)
if len(intersect) < max(min_length, 1):
self.history = None self.history = None
self.label = None self.label = None
self.observation = None self.observation = None
......
...@@ -34,7 +34,8 @@ class ExperimentSetup(RunEnvironment): ...@@ -34,7 +34,8 @@ class ExperimentSetup(RunEnvironment):
limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None): create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
train_min_length=None, val_min_length=None, test_min_length=None):
# create run framework # create run framework
super().__init__() super().__init__()
...@@ -99,18 +100,23 @@ class ExperimentSetup(RunEnvironment): ...@@ -99,18 +100,23 @@ class ExperimentSetup(RunEnvironment):
# train set parameters # train set parameters
self._set_param("start", train_start, default="1997-01-01", scope="general.train") self._set_param("start", train_start, default="1997-01-01", scope="general.train")
self._set_param("end", train_end, default="2007-12-31", scope="general.train") self._set_param("end", train_end, default="2007-12-31", scope="general.train")
self._set_param("min_length", train_min_length, default=90, scope="general.train")
# validation set parameters # validation set parameters
self._set_param("start", val_start, default="2008-01-01", scope="general.val") self._set_param("start", val_start, default="2008-01-01", scope="general.val")
self._set_param("end", val_end, default="2009-12-31", scope="general.val") self._set_param("end", val_end, default="2009-12-31", scope="general.val")
self._set_param("min_length", val_min_length, default=90, scope="general.val")
# test set parameters # test set parameters
self._set_param("start", test_start, default="2010-01-01", scope="general.test") self._set_param("start", test_start, default="2010-01-01", scope="general.test")
self._set_param("end", test_end, default="2017-12-31", scope="general.test") self._set_param("end", test_end, default="2017-12-31", scope="general.test")
self._set_param("min_length", test_min_length, default=90, scope="general.test")
# train_val set parameters # train_val set parameters
self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val") self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val")
self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val") self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val")
train_val_min_length = sum([self.data_store.get("min_length", f"general.{s}") for s in ["train", "val"]])
self._set_param("min_length", train_val_min_length, default=180, scope="general.train_val")
# use all stations on all data sets (train, val, test) # use all stations on all data sets (train, val, test)
self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True) self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True)
......
...@@ -11,7 +11,7 @@ from src.join import EmptyQueryResult ...@@ -11,7 +11,7 @@ from src.join import EmptyQueryResult
from src.run_modules.run_environment import RunEnvironment from src.run_modules.run_environment import RunEnvironment
DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length",
"station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"] "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"]
......
...@@ -287,6 +287,14 @@ class TestDataPrep: ...@@ -287,6 +287,14 @@ class TestDataPrep:
assert remaining_len == data.label.datetime.shape assert remaining_len == data.label.datetime.shape
assert remaining_len == data.observation.datetime.shape assert remaining_len == data.observation.datetime.shape
def test_remove_nan_too_short(self, data):
    """Check that ``remove_nan`` discards a series shorter than ``min_length``.

    ``min_length`` is set above the number of available samples, so after
    ``remove_nan`` the history, label and observation attributes are expected
    to be reset to ``None``.
    """
    data.kwargs["min_length"] = 4000  # actual length of series is 3940
    data.make_history_window('variables', -12, 'datetime')
    data.make_labels('variables', 'o3', 'datetime', 3)
    data.make_observation('variables', 'o3', 'datetime')
    data.remove_nan('datetime')
    # Compare against None explicitly instead of truth-testing the attributes:
    # if one of them were still an array-like object, evaluating its truth
    # value may raise rather than yield a clean assertion failure.
    assert data.history is None
    assert data.label is None
    assert data.observation is None
def test_create_index_array(self, data): def test_create_index_array(self, data):
index_array = data.create_index_array('window', range(1, 4)) index_array = data.create_index_array('window', range(1, 4))
assert np.testing.assert_array_equal(index_array.data, [1, 2, 3]) is None assert np.testing.assert_array_equal(index_array.data, [1, 2, 3]) is None
......
...@@ -85,12 +85,19 @@ class TestExperimentSetup: ...@@ -85,12 +85,19 @@ class TestExperimentSetup:
# train parameters # train parameters
assert data_store.get("start", "general.train") == "1997-01-01" assert data_store.get("start", "general.train") == "1997-01-01"
assert data_store.get("end", "general.train") == "2007-12-31" assert data_store.get("end", "general.train") == "2007-12-31"
assert data_store.get("min_length", "general.train") == 90
# validation parameters # validation parameters
assert data_store.get("start", "general.val") == "2008-01-01" assert data_store.get("start", "general.val") == "2008-01-01"
assert data_store.get("end", "general.val") == "2009-12-31" assert data_store.get("end", "general.val") == "2009-12-31"
assert data_store.get("min_length", "general.val") == 90
# test parameters # test parameters
assert data_store.get("start", "general.test") == "2010-01-01" assert data_store.get("start", "general.test") == "2010-01-01"
assert data_store.get("end", "general.test") == "2017-12-31" assert data_store.get("end", "general.test") == "2017-12-31"
assert data_store.get("min_length", "general.test") == 90
# train_val parameters
assert data_store.get("start", "general.train_val") == "1997-01-01"
assert data_store.get("end", "general.train_val") == "2009-12-31"
assert data_store.get("min_length", "general.train_val") == 180
# use all stations on all data sets (train, val, test) # use all stations on all data sets (train, val, test)
assert data_store.get("use_all_stations_on_all_data_sets", "general") is True assert data_store.get("use_all_stations_on_all_data_sets", "general") is True
...@@ -104,7 +111,7 @@ class TestExperimentSetup: ...@@ -104,7 +111,7 @@ class TestExperimentSetup:
interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01", interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05", train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05",
test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False, test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False,
fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True) fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True, val_min_length=20)
exp_setup = ExperimentSetup(**kwargs) exp_setup = ExperimentSetup(**kwargs)
data_store = exp_setup.data_store data_store = exp_setup.data_store
# experiment setup # experiment setup
...@@ -139,12 +146,19 @@ class TestExperimentSetup: ...@@ -139,12 +146,19 @@ class TestExperimentSetup:
# train parameters # train parameters
assert data_store.get("start", "general.train") == "2000-01-01" assert data_store.get("start", "general.train") == "2000-01-01"
assert data_store.get("end", "general.train") == "2000-01-02" assert data_store.get("end", "general.train") == "2000-01-02"
assert data_store.get("min_length", "general.train") == 90
# validation parameters # validation parameters
assert data_store.get("start", "general.val") == "2000-01-03" assert data_store.get("start", "general.val") == "2000-01-03"
assert data_store.get("end", "general.val") == "2000-01-04" assert data_store.get("end", "general.val") == "2000-01-04"
assert data_store.get("min_length", "general.val") == 20
# test parameters # test parameters
assert data_store.get("start", "general.test") == "2000-01-05" assert data_store.get("start", "general.test") == "2000-01-05"
assert data_store.get("end", "general.test") == "2000-01-06" assert data_store.get("end", "general.test") == "2000-01-06"
assert data_store.get("min_length", "general.test") == 90
# train_val parameters
assert data_store.get("start", "general.train_val") == "2000-01-01"
assert data_store.get("end", "general.train_val") == "2000-01-04"
assert data_store.get("min_length", "general.train_val") == 110
# use all stations on all data sets (train, val, test) # use all stations on all data sets (train, val, test)
assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False
......
...@@ -54,7 +54,7 @@ class TestPreProcessing: ...@@ -54,7 +54,7 @@ class TestPreProcessing:
assert obj_with_exp_setup.data_store.search_name("generator") == [] assert obj_with_exp_setup.data_store.search_name("generator") == []
obj_with_exp_setup.split_train_val_test() obj_with_exp_setup.split_train_val_test()
data_store = obj_with_exp_setup.data_store data_store = obj_with_exp_setup.data_store
expected_params = ["generator", "start", "end", "stations", "permute_data"] expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length"]
assert data_store.search_scope("general.train") == sorted(expected_params) assert data_store.search_scope("general.train") == sorted(expected_params)
assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test", assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test",
"general.train_val"]) "general.train_val"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment