From aa18ad42794e130e5c8db98a81f5c181cf7cc587 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 4 Nov 2019 16:31:44 +0100
Subject: [PATCH] added make_history_window and make_labels

---
 src/data_preparation.py       | 30 +++++++++++++++++++++++++-----
 test/test_data_preparation.py | 17 +++++++++++++++++
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/src/data_preparation.py b/src/data_preparation.py
index e45b3e8b..74dc1363 100644
--- a/src/data_preparation.py
+++ b/src/data_preparation.py
@@ -184,10 +184,21 @@ class DataPrep(object):
         else:
             self.inverse_transform()
 
-    def make_history_window(self, dim, window):
-        raise NotImplementedError
+    def make_history_window(self, dim: str, window: int) -> None:
+        """
+        This function shifts the data window+1 times and builds an xarray which has a new dimension 'window'
+        containing the shifted data. This is used to represent history in the data. Results are stored in self.history.
+
+        :param dim: Dimension along which the shift will be applied
+        :param window: number of time steps to look back in history
+                Note: window is always treated as a negative value, in agreement with looking back on a
+                time line. Positive values are allowed as well, but they are converted to their negative
+                counterpart.
+        """
+        window = -abs(window)
+        self.history = self.shift(dim, window)
 
-    def shift(self, dim: str, window: int):
+    def shift(self, dim: str, window: int) -> xr.DataArray:
         """
         This function uses xarray's shift function multiple times to represent history (if window <= 0)
         or lead time (if window > 0)
@@ -208,8 +219,17 @@ class DataPrep(object):
         res = xr.concat(res, dim=window_array)
         return res
 
-    def make_labels(self, dimension_name_of_target, target_variable, dimension_name_of_shift, window):
-        raise NotImplementedError
+    def make_labels(self, dim_name_of_target: str, target_var: str, dim_name_of_shift: str, window: int) -> None:
+        """
+        This function creates an xarray.DataArray containing labels and stores it in self.label
+
+        :param dim_name_of_target: Name of dimension which contains the target variable
+        :param target_var: Name of target variable in dimension 'dim_name_of_target'
+        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
+        :param window: lead time of label; the sign is ignored and the absolute value is used
+        """
+        window = abs(window)
+        self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var})
 
     def history_label_nan_remove(self, dim: str) -> None:
         """
diff --git a/test/test_data_preparation.py b/test/test_data_preparation.py
index 4102b630..f8a18aea 100644
--- a/test/test_data_preparation.py
+++ b/test/test_data_preparation.py
@@ -189,3 +189,20 @@ class TestDataPrep:
         window, orig = self.extract_window_data(res, data.data, 0)
         assert list(res.data.shape) == [1] + list(data.data.shape)
         assert np.testing.assert_array_equal(orig, window) is None
+
+    def test_make_history_window(self, data):
+        assert data.history is None
+        data.make_history_window('datetime', 5)
+        assert data.history is not None
+        save_history = data.history
+        data.make_history_window('datetime', -5)
+        assert np.testing.assert_array_equal(data.history, save_history) is None
+
+    def test_make_labels(self, data):
+        assert data.label is None
+        data.make_labels('variables', 'o3', 'datetime', 3)
+        assert data.label.variables.data == 'o3'
+        assert list(data.label.shape) == [3] + list(data.data.shape)[:2]
+        save_label = data.label
+        data.make_labels('variables', 'o3', 'datetime', -3)
+        assert np.testing.assert_array_equal(data.label, save_label) is None
-- 
GitLab