From c62e38f4f83cb44f839338404aeb445b766af0b5 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Mon, 2 Mar 2020 11:58:27 +0100 Subject: [PATCH] introduced observation attribute, rename remove nan function, refac make functions for labels and history --- src/data_handling/data_generator.py | 5 +- src/data_handling/data_preparation.py | 29 ++++++-- .../test_data_preparation.py | 74 +++++++++++++------ 3 files changed, 78 insertions(+), 30 deletions(-) diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 7aa24a88..3d42ea7c 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -113,9 +113,10 @@ class DataGenerator(keras.utils.Sequence): **self.kwargs) data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill) data.transform("datetime", method=self.transform_method) - data.make_history_window(self.interpolate_dim, self.window_history_size) + data.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim) data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time) - data.history_label_nan_remove(self.interpolate_dim) + data.make_observation(self.target_dim, self.target_var, self.interpolate_dim) + data.remove_nan(self.interpolate_dim) self._save_pickle_data(data) return data diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 490515aa..9a006beb 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -2,6 +2,7 @@ __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2019-10-16' import datetime as dt +from functools import reduce import logging import os from typing import Union, List, Iterable @@ -15,6 +16,7 @@ from src import statistics # define a more general date type for type hinting date = Union[dt.date, dt.datetime] +str_or_list = Union[str, List[str]] class DataPrep(object): 
@@ -55,6 +57,7 @@ class DataPrep(object): self.std = None self.history = None self.label = None + self.observation = None self.kwargs = kwargs self.data = None self.meta = None @@ -267,19 +270,20 @@ class DataPrep(object): std = None return mean, std, self._transform_method - def make_history_window(self, dim: str, window: int) -> None: + def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: """ This function uses shifts the data window+1 times and returns a xarray which has a new dimension 'window' containing the shifted data. This is used to represent history in the data. Results are stored in self.history . - :param dim: Dimension along shift will be applied + :param dim_name_of_inputs: Name of dimension which contains the input variables :param window: number of time steps to look back in history Note: window will be treated as negative value. This should be in agreement with looking back on a time line. Nonetheless positive values are allowed but they are converted to its negative expression + :param dim_name_of_shift: Dimension along shift will be applied """ window = -abs(window) - self.history = self.shift(dim, window) + self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables}) def shift(self, dim: str, window: int) -> xr.DataArray: """ @@ -302,7 +306,7 @@ class DataPrep(object): res = xr.concat(res, dim=window_array) return res - def make_labels(self, dim_name_of_target: str, target_var: str, dim_name_of_shift: str, window: int) -> None: + def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, window: int) -> None: """ This function creates a xarray.DataArray containing labels @@ -314,7 +318,17 @@ class DataPrep(object): window = abs(window) self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var}) - def history_label_nan_remove(self, dim: str) -> None: + def make_observation(self, dim_name_of_target: 
str, target_var: str_or_list, dim_name_of_shift: str) -> None: + """ + This function creates a xarray.DataArray containing observations + + :param dim_name_of_target: Name of dimension which contains the target variable + :param target_var: Name of target variable(s) in 'dimension' + :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied + """ + self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var}) + + def remove_nan(self, dim: str) -> None: """ All NAs slices in dim which contain nans in self.history or self.label are removed in both data sets. This is done to present only a full matrix to keras.fit. @@ -326,14 +340,17 @@ class DataPrep(object): if (self.history is not None) and (self.label is not None): non_nan_history = self.history.dropna(dim=dim) non_nan_label = self.label.dropna(dim=dim) - intersect = np.intersect1d(non_nan_history.coords[dim].values, non_nan_label.coords[dim].values) + non_nan_observation = self.observation.dropna(dim=dim) + intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values)) if len(intersect) == 0: self.history = None self.label = None + self.observation = None else: self.history = self.history.sel({dim: intersect}) self.label = self.label.sel({dim: intersect}) + self.observation = self.observation.sel({dim: intersect}) @staticmethod def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray: diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 72bacaf9..b38235b4 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -39,7 +39,7 @@ class TestDataPrep: assert data.variables == ['o3', 'temp'] assert data.station_type == "background" assert data.statistics_per_var == {'o3': 'dma8eu', 'temp': 'maximum'} - assert not all([data.mean, 
data.std, data.history, data.label, data.station_type]) + assert not any([data.mean, data.std, data.history, data.label, data.observation]) assert {'test': 'testKWARGS'}.items() <= data.kwargs.items() def test_init_no_stats(self): @@ -221,29 +221,32 @@ class TestDataPrep: assert np.testing.assert_almost_equal(std, std_test) is None assert info == "standardise" - def test_nan_remove_no_hist_or_label(self, data): - assert data.history is None - assert data.label is None - data.history_label_nan_remove('datetime') - assert data.history is None - assert data.label is None - data.make_history_window('datetime', 6) + def test_remove_nan_no_hist_or_label(self, data): + assert not any([data.history, data.label, data.observation]) + data.remove_nan('datetime') + assert not any([data.history, data.label, data.observation]) + data.make_history_window('variables', 6, 'datetime') assert data.history is not None - data.history_label_nan_remove('datetime') + data.remove_nan('datetime') assert data.history is None data.make_labels('variables', 'o3', 'datetime', 2) - assert data.label is not None - data.history_label_nan_remove('datetime') - assert data.label is None + data.make_observation('variables', 'o3', 'datetime') + assert all(map(lambda x: x is not None, [data.label, data.observation])) + data.remove_nan('datetime') + assert not any([data.history, data.label, data.observation]) - def test_nan_remove(self, data): - data.make_history_window('datetime', -12) + def test_remove_nan(self, data): + data.make_history_window('variables', -12, 'datetime') data.make_labels('variables', 'o3', 'datetime', 3) + data.make_observation('variables', 'o3', 'datetime') shape = data.history.shape - data.history_label_nan_remove('datetime') + data.remove_nan('datetime') assert data.history.isnull().sum() == 0 assert itemgetter(0, 1, 3)(shape) == itemgetter(0, 1, 3)(data.history.shape) assert shape[2] >= data.history.shape[2] + remaining_len = data.history.datetime.shape + assert remaining_len == 
data.label.datetime.shape + assert remaining_len == data.observation.datetime.shape def test_create_index_array(self, data): index_array = data.create_index_array('window', range(1, 4)) @@ -273,34 +276,52 @@ class TestDataPrep: res = data.shift('datetime', 4) window, orig = self.extract_window_data(res, data.data, 4) assert res.coords.dims == ('window', 'Stations', 'datetime', 'variables') - assert list(res.data.shape) == [4] + list(data.data.shape) + assert list(res.data.shape) == [4, *data.data.shape] assert np.testing.assert_array_equal(orig, window) is None res = data.shift('datetime', -3) window, orig = self.extract_window_data(res, data.data, -3) - assert list(res.data.shape) == [4] + list(data.data.shape) + assert list(res.data.shape) == [4, *data.data.shape] assert np.testing.assert_array_equal(orig, window) is None res = data.shift('datetime', 0) window, orig = self.extract_window_data(res, data.data, 0) - assert list(res.data.shape) == [1] + list(data.data.shape) + assert list(res.data.shape) == [1, *data.data.shape] assert np.testing.assert_array_equal(orig, window) is None def test_make_history_window(self, data): assert data.history is None - data.make_history_window('datetime', 5) + data.make_history_window("variables", 5, "datetime") assert data.history is not None save_history = data.history - data.make_history_window('datetime', -5) + data.make_history_window("variables", -5, "datetime") assert np.testing.assert_array_equal(data.history, save_history) is None def test_make_labels(self, data): assert data.label is None data.make_labels('variables', 'o3', 'datetime', 3) assert data.label.variables.data == 'o3' - assert list(data.label.shape) == [3] + list(data.data.shape)[:2] - save_label = data.label + assert list(data.label.shape) == [3, *data.data.shape[:2]] + save_label = data.label.copy() data.make_labels('variables', 'o3', 'datetime', -3) assert np.testing.assert_array_equal(data.label, save_label) is None + def test_make_labels_multiple(self, 
data): + assert data.label is None + data.make_labels("variables", ["o3", "temp"], "datetime", 4) + assert all(data.label.variables.data == ["o3", "temp"]) + assert list(data.label.shape) == [4, *data.data.shape[:2], 2] + + def test_make_observation(self, data): + assert data.observation is None + data.make_observation("variables", "o3", "datetime") + assert data.observation.variables.data == "o3" + assert list(data.observation.shape) == [1, 1, data.data.datetime.shape[0]] + + def test_make_observation_multiple(self, data): + assert data.observation is None + data.make_observation("variables", ["o3", "temp"], "datetime") + assert all(data.observation.variables.data == ["o3", "temp"]) + assert list(data.observation.shape) == [1, 1, data.data.datetime.shape[0], 2] + def test_slice(self, data): res = data._slice(data.data, dt.date(1997, 1, 1), dt.date(1997, 1, 10), 'datetime') assert itemgetter(0, 2)(res.shape) == itemgetter(0, 2)(data.data.shape) @@ -326,3 +347,12 @@ class TestDataPrep: data_new = DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'dummy', 'DEBW107', ['o3', 'temp'], station_type='traffic', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}) + def test_get_transposed_history(self, data): + data.make_history_window("variables", 3, "datetime") + transposed = data.get_transposed_history() + assert transposed.coords.dims == ("datetime", "window", "Stations", "variables") + + def test_get_transposed_label(self, data): + data.make_labels("variables", "o3", "datetime", 2) + transposed = data.get_transposed_label() + assert transposed.coords.dims == ("datetime", "window") -- GitLab