diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 60e2da20dfb8d0c4bf9d6ae633a60591a27a9b2b..2889c5526267a35f190a61eb8453344a4ffc1cd2 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -186,7 +186,8 @@ class DataGenerator(keras.utils.Sequence): data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill) data.make_history_window(self.interpolate_dim, self.window_history_size) data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time) - data.history_label_nan_remove(self.interpolate_dim) + data.make_observation(self.target_dim, self.target_var, self.interpolate_dim) + data.remove_nan(self.interpolate_dim) if save_local_tmp_storage: self._save_pickle_data(data) return data diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 594a4733f73e3d18025414f2d382d1840799abdb..3fae09306ab65d18f19d770b525cdc2296215bcd 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -2,6 +2,7 @@ __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2019-10-16' import datetime as dt +from functools import reduce import logging import os from typing import Union, List, Iterable @@ -15,6 +16,7 @@ from src import statistics # define a more general date type for type hinting date = Union[dt.date, dt.datetime] +str_or_list = Union[str, List[str]] class DataPrep(object): @@ -55,6 +57,7 @@ class DataPrep(object): self.std = None self.history = None self.label = None + self.observation = None self.kwargs = kwargs self.data = None self.meta = None @@ -135,10 +138,12 @@ class DataPrep(object): return xarr, meta def _set_file_name(self): - return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(sorted(self.variables))}.nc") + all_vars = sorted(self.statistics_per_var.keys()) + return os.path.join(self.path, 
f"{''.join(self.station)}_{'_'.join(all_vars)}.nc") def _set_meta_file_name(self): - return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(sorted(self.variables))}_meta.csv") + all_vars = sorted(self.statistics_per_var.keys()) + return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv") def __repr__(self): return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \ @@ -275,19 +280,20 @@ class DataPrep(object): std = None return mean, std, self._transform_method - def make_history_window(self, dim: str, window: int) -> None: + def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: """ This function uses shifts the data window+1 times and returns a xarray which has a new dimension 'window' containing the shifted data. This is used to represent history in the data. Results are stored in self.history . - :param dim: Dimension along shift will be applied + :param dim_name_of_inputs: Name of dimension which contains the input variables :param window: number of time steps to look back in history Note: window will be treated as negative value. This should be in agreement with looking back on a time line. 
Nonetheless positive values are allowed but they are converted to its negative expression + :param dim_name_of_shift: Dimension along which shift will be applied """ window = -abs(window) - self.history = self.shift(dim, window) + self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables}) def shift(self, dim: str, window: int) -> xr.DataArray: """ @@ -310,7 +316,7 @@ class DataPrep(object): res = xr.concat(res, dim=window_array) return res - def make_labels(self, dim_name_of_target: str, target_var: str, dim_name_of_shift: str, window: int) -> None: + def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, window: int) -> None: """ This function creates a xarray.DataArray containing labels @@ -322,7 +328,17 @@ window = abs(window) self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var}) - def history_label_nan_remove(self, dim: str) -> None: + def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None: + """ + This function creates a xarray.DataArray containing the observation data (target values at lead time zero) + + :param dim_name_of_target: Name of dimension which contains the target variable + :param target_var: Name of target variable(s) in 'dimension' + :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied + """ + self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var}) + + def remove_nan(self, dim: str) -> None: """ All NAs slices in dim which contain nans in self.history or self.label are removed in both data sets. This is done to present only a full matrix to keras.fit.
@@ -334,14 +350,17 @@ class DataPrep(object): if (self.history is not None) and (self.label is not None): non_nan_history = self.history.dropna(dim=dim) non_nan_label = self.label.dropna(dim=dim) - intersect = np.intersect1d(non_nan_history.coords[dim].values, non_nan_label.coords[dim].values) + non_nan_observation = self.observation.dropna(dim=dim) + intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values)) if len(intersect) == 0: self.history = None self.label = None + self.observation = None else: self.history = self.history.sel({dim: intersect}) self.label = self.label.sel({dim: intersect}) + self.observation = self.observation.sel({dim: intersect}) @staticmethod def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray: diff --git a/src/helpers.py b/src/helpers.py index 621974bfdeae46e87ff62990b68ea326eb38c033..073a7bbf9ae3b7041591d48e4e5b7f3ef0efae42 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -213,3 +213,8 @@ def list_pop(list_full: list, pop_items): list_pop = list_full.copy() list_pop.remove(pop_items[0]) return list_pop + + +def dict_pop(dict_orig: Dict, pop_keys): + pop_keys = to_list(pop_keys) + return {k: v for k, v in dict_orig.items() if k not in pop_keys} diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 49e586d404c88e177946149542517140db1c6ff9..56c22a81e48421438816855770b7477e84e3a8d8 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -28,7 +28,7 @@ class ExperimentSetup(RunEnvironment): trainable: Train new model if true, otherwise try to load existing model """ - def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, station_type=None, variables=None, + def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None, statistics_per_var=None, start=None, end=None, 
window_history_size=None, target_var="o3", target_dim=None, window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None, limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, @@ -70,12 +70,11 @@ class ExperimentSetup(RunEnvironment): helpers.check_path_and_create(self.data_store.get("forecast_path", "general")) # setup for data - self._set_param("var_all_dict", var_all_dict, default=DEFAULT_VAR_ALL_DICT) self._set_param("stations", stations, default=DEFAULT_STATIONS) self._set_param("network", network, default="AIRBASE") self._set_param("station_type", station_type, default=None) - self._set_param("variables", variables, default=list(self.data_store.get("var_all_dict", "general").keys())) - self._set_param("statistics_per_var", statistics_per_var, default=self.data_store.get("var_all_dict", "general")) + self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT) + self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var", "general").keys())) self._compare_variables_and_statistics() self._set_param("start", start, default="1997-01-01", scope="general") self._set_param("end", end, default="2017-12-31", scope="general") @@ -87,6 +86,7 @@ class ExperimentSetup(RunEnvironment): # target self._set_param("target_var", target_var, default="o3") + self._check_target_var() self._set_param("target_dim", target_dim, default='variables') self._set_param("window_lead_time", window_lead_time, default=3) @@ -136,16 +136,27 @@ class ExperimentSetup(RunEnvironment): return {} def _compare_variables_and_statistics(self): - logging.debug("check if all variables are included in statistics_per_var") - var = self.data_store.get("variables", "general") stat = self.data_store.get("statistics_per_var", "general") + var = self.data_store.get("variables", "general") if not set(var).issubset(stat.keys()): missing = set(var).difference(stat.keys()) raise 
ValueError(f"Comparison of given variables and statistics_per_var show that not all requested " f"variables are part of statistics_per_var. Please add also information on the missing " f"statistics for the variables: {missing}") + def _check_target_var(self): + target_var = helpers.to_list(self.data_store.get("target_var", "general")) + stat = self.data_store.get("statistics_per_var", "general") + var = self.data_store.get("variables", "general") + if not set(target_var).issubset(stat.keys()): + raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.") + unused_vars = set(stat.keys()).difference(set(var).union(target_var)) + if len(unused_vars) > 0: + logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}") + stat_new = helpers.dict_pop(stat, list(unused_vars)) + self._set_param("statistics_per_var", stat_new) + if __name__ == "__main__": diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 5c392a402da47251c51668e0b06a3067104a61e6..962c9f52065729381ce11e8a8adcbeed45a4c011 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -168,7 +168,7 @@ class PostProcessing(RunEnvironment): nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised) # persistence - persistence_prediction = self._create_persistence_forecast(input_data, persistence_prediction, mean, std, + persistence_prediction = self._create_persistence_forecast(data, persistence_prediction, mean, std, transformation_method, normalised) # ols @@ -197,7 +197,7 @@ class PostProcessing(RunEnvironment): @staticmethod def _create_observation(data, _, mean, std, transformation_method, normalised): - obs = data.label.copy() + obs = data.observation.copy() if not normalised: obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method) return obs @@ -211,8 +211,8 @@ class 
PostProcessing(RunEnvironment): ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols return ols_prediction - def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method, normalised): - tmp_persi = input_data.sel({'window': 0, 'variables': self.target_var}) + def _create_persistence_forecast(self, data, persistence_prediction, mean, std, transformation_method, normalised): + tmp_persi = data.observation.copy().sel({'window': 0}) if not normalised: tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) window_lead_time = self.data_store.get("window_lead_time", "general") @@ -295,7 +295,7 @@ class PostProcessing(RunEnvironment): try: data = self.train_val_data.get_data_generator(station) mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) - external_data = self._create_observation(data, None, mean, std, transformation_method) + external_data = self._create_observation(data, None, mean, std, transformation_method, normalised=False) external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"]) return external_data.rename({'datetime': 'index'}) except KeyError: diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index ac449c4dc6d4c83a457eccc93a766ec4f17f58c9..53f80ce5cfe248ede1127b03956d58bb7f70a783 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -39,7 +39,7 @@ class TestDataPrep: assert data.variables == ['o3', 'temp'] assert data.station_type == "background" assert data.statistics_per_var == {'o3': 'dma8eu', 'temp': 'maximum'} - assert not all([data.mean, data.std, data.history, data.label, data.station_type]) + assert not any([data.mean, data.std, data.history, data.label, data.observation]) assert {'test': 
'testKWARGS'}.items() <= data.kwargs.items() def test_init_no_stats(self): @@ -258,29 +258,32 @@ class TestDataPrep: assert np.testing.assert_almost_equal(std, std_test) is None assert info == "standardise" - def test_nan_remove_no_hist_or_label(self, data): - assert data.history is None - assert data.label is None - data.history_label_nan_remove('datetime') - assert data.history is None - assert data.label is None - data.make_history_window('datetime', 6) + def test_remove_nan_no_hist_or_label(self, data): + assert not any([data.history, data.label, data.observation]) + data.remove_nan('datetime') + assert not any([data.history, data.label, data.observation]) + data.make_history_window('variables', 6, 'datetime') assert data.history is not None - data.history_label_nan_remove('datetime') + data.remove_nan('datetime') assert data.history is None data.make_labels('variables', 'o3', 'datetime', 2) - assert data.label is not None - data.history_label_nan_remove('datetime') - assert data.label is None + data.make_observation('variables', 'o3', 'datetime') + assert all(map(lambda x: x is not None, [data.label, data.observation])) + data.remove_nan('datetime') + assert not any([data.history, data.label, data.observation]) - def test_nan_remove(self, data): - data.make_history_window('datetime', -12) + def test_remove_nan(self, data): + data.make_history_window('variables', -12, 'datetime') data.make_labels('variables', 'o3', 'datetime', 3) + data.make_observation('variables', 'o3', 'datetime') shape = data.history.shape - data.history_label_nan_remove('datetime') + data.remove_nan('datetime') assert data.history.isnull().sum() == 0 assert itemgetter(0, 1, 3)(shape) == itemgetter(0, 1, 3)(data.history.shape) assert shape[2] >= data.history.shape[2] + remaining_len = data.history.datetime.shape + assert remaining_len == data.label.datetime.shape + assert remaining_len == data.observation.datetime.shape def test_create_index_array(self, data): index_array = 
data.create_index_array('window', range(1, 4)) @@ -310,34 +313,52 @@ class TestDataPrep: res = data.shift('datetime', 4) window, orig = self.extract_window_data(res, data.data, 4) assert res.coords.dims == ('window', 'Stations', 'datetime', 'variables') - assert list(res.data.shape) == [4] + list(data.data.shape) + assert list(res.data.shape) == [4, *data.data.shape] assert np.testing.assert_array_equal(orig, window) is None res = data.shift('datetime', -3) window, orig = self.extract_window_data(res, data.data, -3) - assert list(res.data.shape) == [4] + list(data.data.shape) + assert list(res.data.shape) == [4, *data.data.shape] assert np.testing.assert_array_equal(orig, window) is None res = data.shift('datetime', 0) window, orig = self.extract_window_data(res, data.data, 0) - assert list(res.data.shape) == [1] + list(data.data.shape) + assert list(res.data.shape) == [1, *data.data.shape] assert np.testing.assert_array_equal(orig, window) is None def test_make_history_window(self, data): assert data.history is None - data.make_history_window('datetime', 5) + data.make_history_window("variables", 5, "datetime") assert data.history is not None save_history = data.history - data.make_history_window('datetime', -5) + data.make_history_window("variables", -5, "datetime") assert np.testing.assert_array_equal(data.history, save_history) is None def test_make_labels(self, data): assert data.label is None data.make_labels('variables', 'o3', 'datetime', 3) assert data.label.variables.data == 'o3' - assert list(data.label.shape) == [3] + list(data.data.shape)[:2] - save_label = data.label + assert list(data.label.shape) == [3, *data.data.shape[:2]] + save_label = data.label.copy() data.make_labels('variables', 'o3', 'datetime', -3) assert np.testing.assert_array_equal(data.label, save_label) is None + def test_make_labels_multiple(self, data): + assert data.label is None + data.make_labels("variables", ["o3", "temp"], "datetime", 4) + assert all(data.label.variables.data == 
["o3", "temp"]) + assert list(data.label.shape) == [4, *data.data.shape[:2], 2] + + def test_make_observation(self, data): + assert data.observation is None + data.make_observation("variables", "o3", "datetime") + assert data.observation.variables.data == "o3" + assert list(data.observation.shape) == [1, 1, data.data.datetime.shape[0]] + + def test_make_observation_multiple(self, data): + assert data.observation is None + data.make_observation("variables", ["o3", "temp"], "datetime") + assert all(data.observation.variables.data == ["o3", "temp"]) + assert list(data.observation.shape) == [1, 1, data.data.datetime.shape[0], 2] + def test_slice(self, data): res = data._slice(data.data, dt.date(1997, 1, 1), dt.date(1997, 1, 10), 'datetime') assert itemgetter(0, 2)(res.shape) == itemgetter(0, 2)(data.data.shape) @@ -363,3 +384,12 @@ class TestDataPrep: data_new = DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'dummy', 'DEBW107', ['o3', 'temp'], station_type='traffic', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + def test_get_transposed_history(self, data): + data.make_history_window("variables", 3, "datetime") + transposed = data.get_transposed_history() + assert transposed.coords.dims == ("datetime", "window", "Stations", "variables") + + def test_get_transposed_label(self, data): + data.make_labels("variables", "o3", "datetime", 2) + transposed = data.get_transposed_label() + assert transposed.coords.dims == ("datetime", "window")