From aa18ad42794e130e5c8db98a81f5c181cf7cc587 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Mon, 4 Nov 2019 16:31:44 +0100 Subject: [PATCH] added make_history_window and make_labels --- src/data_preparation.py | 30 +++++++++++++++++++++++++----- test/test_data_preparation.py | 17 +++++++++++++++++ 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/data_preparation.py b/src/data_preparation.py index e45b3e8b..74dc1363 100644 --- a/src/data_preparation.py +++ b/src/data_preparation.py @@ -184,10 +184,21 @@ class DataPrep(object): else: self.inverse_transform() - def make_history_window(self, dim, window): - raise NotImplementedError + def make_history_window(self, dim: str, window: int) -> None: + """ + This function shifts the data window+1 times and creates an xarray which has a new dimension 'window' + containing the shifted data. This is used to represent history in the data. Results are stored in self.history. + + :param dim: Dimension along which the shift will be applied + :param window: number of time steps to look back in history + Note: window will be treated as a negative value. This should be in agreement with looking back on + a time line. 
Nonetheless, positive values are allowed, but they are converted to their negative + equivalent + """ + window = -abs(window) + self.history = self.shift(dim, window) - def shift(self, dim: str, window: int): + def shift(self, dim: str, window: int) -> xr.DataArray: """ This function uses xarray's shift function multiple times to represent history (if window <= 0) or lead time (if window > 0) @@ -208,8 +219,17 @@ class DataPrep(object): res = xr.concat(res, dim=window_array) return res - def make_labels(self, dimension_name_of_target, target_variable, dimension_name_of_shift, window): - raise NotImplementedError + def make_labels(self, dim_name_of_target: str, target_var: str, dim_name_of_shift: str, window: int) -> None: + """ + This function creates an xarray.DataArray containing labels + + :param dim_name_of_target: Name of dimension which contains the target variable + :param target_var: Name of target variable in dimension 'dim_name_of_target' + :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied + :param window: lead time of label + """ + window = abs(window) + self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var}) def history_label_nan_remove(self, dim: str) -> None: """ diff --git a/test/test_data_preparation.py b/test/test_data_preparation.py index 4102b630..f8a18aea 100644 --- a/test/test_data_preparation.py +++ b/test/test_data_preparation.py @@ -189,3 +189,20 @@ class TestDataPrep: window, orig = self.extract_window_data(res, data.data, 0) assert list(res.data.shape) == [1] + list(data.data.shape) assert np.testing.assert_array_equal(orig, window) is None + + def test_make_history_window(self, data): + assert data.history is None + data.make_history_window('datetime', 5) + assert data.history is not None + save_history = data.history + data.make_history_window('datetime', -5) + assert np.testing.assert_array_equal(data.history, save_history) is None + + def test_make_labels(self, data): + assert 
data.label is None + data.make_labels('variables', 'o3', 'datetime', 3) + assert data.label.variables.data == 'o3' + assert list(data.label.shape) == [3] + list(data.data.shape)[:2] + save_label = data.label + data.make_labels('variables', 'o3', 'datetime', -3) + assert np.testing.assert_array_equal(data.label, save_label) is None -- GitLab