diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 3fae09306ab65d18f19d770b525cdc2296215bcd..e3186778b94375ba1d39fa87ba7d2980c785581e 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -353,7 +353,8 @@ class DataPrep(object): non_nan_observation = self.observation.dropna(dim=dim) intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values)) - if len(intersect) == 0: + min_length = self.kwargs.get("min_length", 0) + if len(intersect) < max(min_length, 1): self.history = None self.label = None self.observation = None diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 56c22a81e48421438816855770b7477e84e3a8d8..2aafe6c693b8d10fe450df510539e0aef8f0487e 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -34,7 +34,8 @@ class ExperimentSetup(RunEnvironment): limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", - create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None): + create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None, + train_min_length=None, val_min_length=None, test_min_length=None): # create run framework super().__init__() @@ -99,18 +100,23 @@ class ExperimentSetup(RunEnvironment): # train set parameters self._set_param("start", train_start, default="1997-01-01", scope="general.train") self._set_param("end", train_end, default="2007-12-31", scope="general.train") + self._set_param("min_length", train_min_length, default=90, scope="general.train") # validation set parameters self._set_param("start", val_start, default="2008-01-01", scope="general.val") self._set_param("end", val_end, default="2009-12-31", scope="general.val") + self._set_param("min_length", val_min_length, default=90, scope="general.val") # test set parameters self._set_param("start", test_start, default="2010-01-01", scope="general.test") self._set_param("end", test_end, default="2017-12-31", scope="general.test") + self._set_param("min_length", test_min_length, default=90, scope="general.test") # train_val set parameters self._set_param("start", self.data_store.get("start", "general.train"), scope="general.train_val") self._set_param("end", self.data_store.get("end", "general.val"), scope="general.train_val") + train_val_min_length = sum([self.data_store.get("min_length", f"general.{s}") for s in ["train", "val"]]) + self._set_param("min_length", train_val_min_length, default=180, scope="general.train_val") # use all stations on all data sets (train, val, test) self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True) diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 1d014c9e6f4fc0a9168c4d3d31b1141c39fff2a1..20286bc43b3227291c66c7844ad43792a7a28480 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -11,7 +11,7 @@ from src.join import EmptyQueryResult from src.run_modules.run_environment import RunEnvironment DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] -DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", +DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_time", "statistics_per_var", "min_length", "station_type", "overwrite_local_data", "start", "end", "sampling", "transformation"] diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 91719f3dd16326ee6281c4db8ef3aa87e238d70f..85c4420609a466ff5f3eeb3d46cb6bb07fe9c30a 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -287,6 +287,14 @@ class TestDataPrep: assert remaining_len == data.label.datetime.shape assert remaining_len == data.observation.datetime.shape + def test_remove_nan_too_short(self, data): + data.kwargs["min_length"] = 4000 # actual length of series is 3940 + data.make_history_window('variables', -12, 'datetime') + data.make_labels('variables', 'o3', 'datetime', 3) + data.make_observation('variables', 'o3', 'datetime') + data.remove_nan('datetime') + assert not any([data.history, data.label, data.observation]) + def test_create_index_array(self, data): index_array = data.create_index_array('window', range(1, 4)) assert np.testing.assert_array_equal(index_array.data, [1, 2, 3]) is None diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py index 894e4b552af4231ccc12fb85aaaebf5bbc23edf3..a3a83acf84e286d1f5da9b5caffa256fc0ca3327 100644 --- a/test/test_modules/test_experiment_setup.py +++ b/test/test_modules/test_experiment_setup.py @@ -85,12 +85,19 @@ class TestExperimentSetup: # train parameters assert data_store.get("start", "general.train") == "1997-01-01" assert data_store.get("end", "general.train") == "2007-12-31" + assert data_store.get("min_length", "general.train") == 90 # validation parameters assert data_store.get("start", "general.val") == "2008-01-01" assert data_store.get("end", "general.val") == "2009-12-31" + assert data_store.get("min_length", "general.val") == 90 # test parameters assert data_store.get("start", "general.test") == "2010-01-01" assert data_store.get("end", "general.test") == "2017-12-31" + assert data_store.get("min_length", "general.test") == 90 + # train_val parameters + assert data_store.get("start", "general.train_val") == "1997-01-01" + assert data_store.get("end", "general.train_val") == "2009-12-31" + assert data_store.get("min_length", "general.train_val") == 180 # use all stations on all data sets (train, val, test) assert data_store.get("use_all_stations_on_all_data_sets", "general") is True @@ -104,7 +111,7 @@ class TestExperimentSetup: interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01", train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05", test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False, - fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True) + fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True, val_min_length=20) exp_setup = ExperimentSetup(**kwargs) data_store = exp_setup.data_store # experiment setup @@ -139,12 +146,19 @@ class TestExperimentSetup: # train parameters assert data_store.get("start", "general.train") == "2000-01-01" assert data_store.get("end", "general.train") == "2000-01-02" + assert data_store.get("min_length", "general.train") == 90 # validation parameters assert data_store.get("start", "general.val") == "2000-01-03" assert data_store.get("end", "general.val") == "2000-01-04" + assert data_store.get("min_length", "general.val") == 20 # test parameters assert data_store.get("start", "general.test") == "2000-01-05" assert data_store.get("end", "general.test") == "2000-01-06" + assert data_store.get("min_length", "general.test") == 90 + # train_val parameters + assert data_store.get("start", "general.train_val") == "2000-01-01" + assert data_store.get("end", "general.train_val") == "2000-01-04" + assert data_store.get("min_length", "general.train_val") == 110 # use all stations on all data sets (train, val, test) assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index d58cbd41e2ce4f25f4cd79127256e313b4aac649..c3f13e1ac1d7bdb0bdf17f81d3385472eaa46640 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -54,7 +54,7 @@ class TestPreProcessing: assert obj_with_exp_setup.data_store.search_name("generator") == [] obj_with_exp_setup.split_train_val_test() data_store = obj_with_exp_setup.data_store - expected_params = ["generator", "start", "end", "stations", "permute_data"] + expected_params = ["generator", "start", "end", "stations", "permute_data", "min_length"] assert data_store.search_scope("general.train") == sorted(expected_params) assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test", "general.train_val"])