From c62e38f4f83cb44f839338404aeb445b766af0b5 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 2 Mar 2020 11:58:27 +0100
Subject: [PATCH] introduced observation attribute, rename remove nan function,
 refac make functions for labels and history

---
 src/data_handling/data_generator.py           |  5 +-
 src/data_handling/data_preparation.py         | 29 ++++++--
 .../test_data_preparation.py                  | 74 +++++++++++++------
 3 files changed, 78 insertions(+), 30 deletions(-)

diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 7aa24a88..3d42ea7c 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -113,9 +113,10 @@ class DataGenerator(keras.utils.Sequence):
                             **self.kwargs)
             data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
             data.transform("datetime", method=self.transform_method)
-            data.make_history_window(self.interpolate_dim, self.window_history_size)
+            data.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim)
             data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
-            data.history_label_nan_remove(self.interpolate_dim)
+            data.make_observation(self.target_dim, self.target_var, self.interpolate_dim)
+            data.remove_nan(self.interpolate_dim)
             self._save_pickle_data(data)
         return data
 
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 490515aa..9a006beb 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -2,6 +2,7 @@ __author__ = 'Felix Kleinert, Lukas Leufen'
 __date__ = '2019-10-16'
 
 import datetime as dt
+from functools import reduce
 import logging
 import os
 from typing import Union, List, Iterable
@@ -15,6 +16,7 @@ from src import statistics
 
 # define a more general date type for type hinting
 date = Union[dt.date, dt.datetime]
+str_or_list = Union[str, List[str]]
 
 
 class DataPrep(object):
@@ -55,6 +57,7 @@ class DataPrep(object):
         self.std = None
         self.history = None
         self.label = None
+        self.observation = None
         self.kwargs = kwargs
         self.data = None
         self.meta = None
@@ -267,19 +270,20 @@ class DataPrep(object):
             std = None
         return mean, std, self._transform_method
 
-    def make_history_window(self, dim: str, window: int) -> None:
+    def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None:
         """
         This function uses shifts the data window+1 times and returns a xarray which has a new dimension 'window'
         containing the shifted data. This is used to represent history in the data. Results are stored in self.history .
 
-        :param dim: Dimension along shift will be applied
+        :param dim_name_of_inputs: Name of dimension which contains the input variables
         :param window: number of time steps to look back in history
                 Note: window will be treated as negative value. This should be in agreement with looking back on
                 a time line. Nonetheless positive values are allowed but they are converted to its negative
                 expression
+        :param dim_name_of_shift: Dimension along shift will be applied
         """
         window = -abs(window)
-        self.history = self.shift(dim, window)
+        self.history = self.shift(dim_name_of_shift, window).sel({dim_name_of_inputs: self.variables})
 
     def shift(self, dim: str, window: int) -> xr.DataArray:
         """
@@ -302,7 +306,7 @@ class DataPrep(object):
         res = xr.concat(res, dim=window_array)
         return res
 
-    def make_labels(self, dim_name_of_target: str, target_var: str, dim_name_of_shift: str, window: int) -> None:
+    def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, window: int) -> None:
         """
         This function creates a xarray.DataArray containing labels
 
@@ -314,7 +318,17 @@ class DataPrep(object):
         window = abs(window)
         self.label = self.shift(dim_name_of_shift, window).sel({dim_name_of_target: target_var})
 
-    def history_label_nan_remove(self, dim: str) -> None:
+    def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
+        """
+        This function creates a xarray.DataArray containing labels
+
+        :param dim_name_of_target: Name of dimension which contains the target variable
+        :param target_var: Name of target variable(s) in 'dimension'
+        :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
+        """
+        self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var})
+
+    def remove_nan(self, dim: str) -> None:
         """
         All NAs slices in dim which contain nans in self.history or self.label are removed in both data sets.
         This is done to present only a full matrix to keras.fit.
@@ -326,14 +340,17 @@ class DataPrep(object):
         if (self.history is not None) and (self.label is not None):
             non_nan_history = self.history.dropna(dim=dim)
             non_nan_label = self.label.dropna(dim=dim)
-            intersect = np.intersect1d(non_nan_history.coords[dim].values, non_nan_label.coords[dim].values)
+            non_nan_observation = self.observation.dropna(dim=dim)
+            intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values))
 
         if len(intersect) == 0:
             self.history = None
             self.label = None
+            self.observation = None
         else:
             self.history = self.history.sel({dim: intersect})
             self.label = self.label.sel({dim: intersect})
+            self.observation = self.observation.sel({dim: intersect})
 
     @staticmethod
     def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray:
diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py
index 72bacaf9..b38235b4 100644
--- a/test/test_data_handling/test_data_preparation.py
+++ b/test/test_data_handling/test_data_preparation.py
@@ -39,7 +39,7 @@ class TestDataPrep:
         assert data.variables == ['o3', 'temp']
         assert data.station_type == "background"
         assert data.statistics_per_var == {'o3': 'dma8eu', 'temp': 'maximum'}
-        assert not all([data.mean, data.std, data.history, data.label, data.station_type])
+        assert not any([data.mean, data.std, data.history, data.label, data.observation])
         assert {'test': 'testKWARGS'}.items() <= data.kwargs.items()
 
     def test_init_no_stats(self):
@@ -221,29 +221,32 @@ class TestDataPrep:
         assert np.testing.assert_almost_equal(std, std_test) is None
         assert info == "standardise"
 
-    def test_nan_remove_no_hist_or_label(self, data):
-        assert data.history is None
-        assert data.label is None
-        data.history_label_nan_remove('datetime')
-        assert data.history is None
-        assert data.label is None
-        data.make_history_window('datetime', 6)
+    def test_remove_nan_no_hist_or_label(self, data):
+        assert not any([data.history, data.label, data.observation])
+        data.remove_nan('datetime')
+        assert not any([data.history, data.label, data.observation])
+        data.make_history_window('variables', 6, 'datetime')
         assert data.history is not None
-        data.history_label_nan_remove('datetime')
+        data.remove_nan('datetime')
         assert data.history is None
         data.make_labels('variables', 'o3', 'datetime', 2)
-        assert data.label is not None
-        data.history_label_nan_remove('datetime')
-        assert data.label is None
+        data.make_observation('variables', 'o3', 'datetime')
+        assert all(map(lambda x: x is not None, [data.label, data.observation]))
+        data.remove_nan('datetime')
+        assert not any([data.history, data.label, data.observation])
 
-    def test_nan_remove(self, data):
-        data.make_history_window('datetime', -12)
+    def test_remove_nan(self, data):
+        data.make_history_window('variables', -12, 'datetime')
         data.make_labels('variables', 'o3', 'datetime', 3)
+        data.make_observation('variables', 'o3', 'datetime')
         shape = data.history.shape
-        data.history_label_nan_remove('datetime')
+        data.remove_nan('datetime')
         assert data.history.isnull().sum() == 0
         assert itemgetter(0, 1, 3)(shape) == itemgetter(0, 1, 3)(data.history.shape)
         assert shape[2] >= data.history.shape[2]
+        remaining_len = data.history.datetime.shape
+        assert remaining_len == data.label.datetime.shape
+        assert remaining_len == data.observation.datetime.shape
 
     def test_create_index_array(self, data):
         index_array = data.create_index_array('window', range(1, 4))
@@ -273,34 +276,52 @@ class TestDataPrep:
         res = data.shift('datetime', 4)
         window, orig = self.extract_window_data(res, data.data, 4)
         assert res.coords.dims == ('window', 'Stations', 'datetime', 'variables')
-        assert list(res.data.shape) == [4] + list(data.data.shape)
+        assert list(res.data.shape) == [4, *data.data.shape]
         assert np.testing.assert_array_equal(orig, window) is None
         res = data.shift('datetime', -3)
         window, orig = self.extract_window_data(res, data.data, -3)
-        assert list(res.data.shape) == [4] + list(data.data.shape)
+        assert list(res.data.shape) == [4, *data.data.shape]
         assert np.testing.assert_array_equal(orig, window) is None
         res = data.shift('datetime', 0)
         window, orig = self.extract_window_data(res, data.data, 0)
-        assert list(res.data.shape) == [1] + list(data.data.shape)
+        assert list(res.data.shape) == [1, *data.data.shape]
         assert np.testing.assert_array_equal(orig, window) is None
 
     def test_make_history_window(self, data):
         assert data.history is None
-        data.make_history_window('datetime', 5)
+        data.make_history_window("variables", 5, "datetime")
         assert data.history is not None
         save_history = data.history
-        data.make_history_window('datetime', -5)
+        data.make_history_window("variables", -5, "datetime")
         assert np.testing.assert_array_equal(data.history, save_history) is None
 
     def test_make_labels(self, data):
         assert data.label is None
         data.make_labels('variables', 'o3', 'datetime', 3)
         assert data.label.variables.data == 'o3'
-        assert list(data.label.shape) == [3] + list(data.data.shape)[:2]
-        save_label = data.label
+        assert list(data.label.shape) == [3, *data.data.shape[:2]]
+        save_label = data.label.copy()
         data.make_labels('variables', 'o3', 'datetime', -3)
         assert np.testing.assert_array_equal(data.label, save_label) is None
 
+    def test_make_labels_multiple(self, data):
+        assert data.label is None
+        data.make_labels("variables", ["o3", "temp"], "datetime", 4)
+        assert all(data.label.variables.data == ["o3", "temp"])
+        assert list(data.label.shape) == [4, *data.data.shape[:2], 2]
+
+    def test_make_observation(self, data):
+        assert data.observation is None
+        data.make_observation("variables", "o3", "datetime")
+        assert data.observation.variables.data == "o3"
+        assert list(data.observation.shape) == [1, 1, data.data.datetime.shape[0]]
+
+    def test_make_observation_multiple(self, data):
+        assert data.observation is None
+        data.make_observation("variables", ["o3", "temp"], "datetime")
+        assert all(data.observation.variables.data == ["o3", "temp"])
+        assert list(data.observation.shape) == [1, 1, data.data.datetime.shape[0], 2]
+
     def test_slice(self, data):
         res = data._slice(data.data, dt.date(1997, 1, 1), dt.date(1997, 1, 10), 'datetime')
         assert itemgetter(0, 2)(res.shape) == itemgetter(0, 2)(data.data.shape)
@@ -326,3 +347,12 @@ class TestDataPrep:
             data_new = DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'dummy', 'DEBW107', ['o3', 'temp'],
                                 station_type='traffic', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
 
+    def test_get_transposed_history(self, data):
+        data.make_history_window("variables", 3, "datetime")
+        transposed = data.get_transposed_history()
+        assert transposed.coords.dims == ("datetime", "window", "Stations", "variables")
+
+    def test_get_transposed_label(self, data):
+        data.make_labels("variables", "o3", "datetime", 2)
+        transposed = data.get_transposed_label()
+        assert transposed.coords.dims == ("datetime", "window")
-- 
GitLab