diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 2889c5526267a35f190a61eb8453344a4ffc1cd2..24c9ada65b4bfd71de12785b2714cc5de94dc21f 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -184,7 +184,7 @@ class DataGenerator(keras.utils.Sequence): if self.transformation is not None: data.transform("datetime", **helpers.dict_pop(self.transformation, "scope")) data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill) - data.make_history_window(self.interpolate_dim, self.window_history_size) + data.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim) data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time) data.make_observation(self.target_dim, self.target_var, self.interpolate_dim) data.remove_nan(self.interpolate_dim) diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 53f80ce5cfe248ede1127b03956d58bb7f70a783..91719f3dd16326ee6281c4db8ef3aa87e238d70f 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -27,6 +27,7 @@ class TestDataPrep: d.network = 'UBA' d.station = ['DEBW107'] d.variables = ['o3', 'temp'] + d.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'} d.station_type = "background" d.sampling = "daily" d.kwargs = None @@ -125,6 +126,7 @@ class TestDataPrep: d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "data") d.station = 'TESTSTATION' d.variables = ['a', 'bc'] + d.statistics_per_var = {'a': 'dma8eu', 'bc': 'maximum'} assert d._set_file_name() == os.path.join(os.path.abspath(os.path.dirname(__file__)), "data/TESTSTATION_a_bc.nc") assert d._set_meta_file_name() == os.path.join(os.path.abspath(os.path.dirname(__file__)), diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py index 9e6d17627d1697a2150ea7f74a373a720d2f02ac..894e4b552af4231ccc12fb85aaaebf5bbc23edf3 100644 --- a/test/test_modules/test_experiment_setup.py +++ b/test/test_modules/test_experiment_setup.py @@ -54,11 +54,10 @@ class TestExperimentSetup: assert data_store.get("experiment_name", "general") == "TestExperiment" path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment")) assert data_store.get("experiment_path", "general") == path - default_var_all_dict = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', - 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', - 'pblheight': 'maximum'} + default_statistics_per_var = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', + 'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', + 'cloudcover': 'average_values', 'pblheight': 'maximum'} # setup for data - assert data_store.get("var_all_dict", "general") == default_var_all_dict default_stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022', 'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039', 'DEBW038', 'DEBW081', 'DEBY075', 'DEBW040', 'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072', @@ -69,8 +68,8 @@ class TestExperimentSetup: assert data_store.get("stations", "general") == default_stations assert data_store.get("network", "general") == "AIRBASE" assert data_store.get("station_type", "general") is None - assert data_store.get("variables", "general") == list(default_var_all_dict.keys()) - assert data_store.get("statistics_per_var", "general") == default_var_all_dict + assert data_store.get("variables", "general") == list(default_statistics_per_var.keys()) + assert data_store.get("statistics_per_var", "general") == default_statistics_per_var assert data_store.get("start", "general") == "1997-01-01" assert data_store.get("end", "general") == "2017-12-31" assert data_store.get("window_history_size", "general") == 13 @@ -98,11 +97,10 @@ class TestExperimentSetup: def test_init_no_default(self): experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder")) kwargs = dict(parser_args={"experiment_date": "TODAY"}, - var_all_dict={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'}, + statistics_per_var={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'}, stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", station_type="background", - variables=["o3", "temp"], - statistics_per_var=None, start="1999-01-01", end="2001-01-01", window_history_size=4, - target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1", + variables=["o3", "temp"], start="1999-01-01", end="2001-01-01", window_history_size=4, + target_var="relhum", target_dim="target", window_lead_time=10, dimensions="dim1", interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01", train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05", test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False, @@ -120,8 +118,6 @@ class TestExperimentSetup: "TODAY_network")) assert data_store.get("experiment_path", "general") == path # setup for data - assert data_store.get("var_all_dict", "general") == {'o3': 'dma8eu', 'relhum': 'average_values', - 'temp': 'maximum'} assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027'] assert data_store.get("network", "general") == "INTERNET" assert data_store.get("station_type", "general") == "background" @@ -132,7 +128,7 @@ class TestExperimentSetup: assert data_store.get("end", "general") == "2001-01-01" assert data_store.get("window_history_size", "general") == 4 # target - assert data_store.get("target_var", "general") == "temp" + assert data_store.get("target_var", "general") == "relhum" assert data_store.get("target_dim", "general") == "target" assert data_store.get("window_lead_time", "general") == 10 # interpolation @@ -173,8 +169,8 @@ class TestExperimentSetup: def test_compare_variables_and_statistics(self): experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder")) kwargs = dict(parser_args={"experiment_date": "TODAY"}, - var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, - stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"], statistics_per_var=None, + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, + stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"], experiment_path=experiment_path) with pytest.raises(ValueError) as e: ExperimentSetup(**kwargs) diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index 29172a1b8500b605859e925574535c6158c7d805..d58cbd41e2ce4f25f4cd79127256e313b4aac649 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -27,7 +27,7 @@ class TestPreProcessing: @pytest.fixture def obj_with_exp_setup(self): ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], - var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background") + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background") pre = object.__new__(PreProcessing) super(PreProcessing, pre).__init__() yield pre @@ -35,7 +35,7 @@ class TestPreProcessing: def test_init(self, caplog): ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], - var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) + statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) caplog.set_level(logging.INFO) with PreProcessing(): assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')