From ec34a136d2c4cd0dfc79001f82bbce59ddd30b10 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 28 Apr 2021 12:38:29 +0200
Subject: [PATCH] new class DataHandlerMixedSamplingWithFilterSingleStation
 that bundles common methods of the KZ and FIR filter handlers when used with
 mixed sampling

---
 .../data_handler_mixed_sampling.py            | 65 ++++++++++++++-----
 .../data_handler/data_handler_with_filter.py  | 18 ++++-
 .../test_data_handler_mixed_sampling.py       | 16 ++---
 3 files changed, 72 insertions(+), 27 deletions(-)

diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index 4c84866b..71f9fe73 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -2,7 +2,8 @@ __author__ = 'Lukas Leufen'
 __date__ = '2020-11-05'
 
 from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
-from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation
+from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation, \
+    DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation
 from mlair.data_handler import DefaultDataHandler
 from mlair import helpers
 from mlair.helpers import remove_items
@@ -94,8 +95,8 @@ class DataHandlerMixedSampling(DefaultDataHandler):
 
 
 class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation,
-                                                      DataHandlerKzFilterSingleStation):
-    _requirements1 = DataHandlerKzFilterSingleStation.requirements()
+                                                      DataHandlerFilterSingleStation):
+    _requirements1 = DataHandlerFilterSingleStation.requirements()
     _requirements2 = DataHandlerMixedSamplingSingleStation.requirements()
     _requirements = list(set(_requirements1 + _requirements2))
 
@@ -107,19 +108,16 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
 
     def make_input_target(self):
         """
-        A KZ filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values
+        A filter is applied on the input data that has hourly resolution. Labels Y are provided as aggregated values
         with daily resolution.
         """
         self._data = tuple(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
         self.set_inputs_and_targets()
-        self.apply_kz_filter()
+        self.apply_filter()
 
     def estimate_filter_width(self):
-        """
-        f = 0.5 / (len * sqrt(itr)) -> T = 1 / f
-        :return:
-        """
-        return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2)
+        """Return maximum filter width."""
+        raise NotImplementedError
 
     @staticmethod
     def _add_time_delta(date, delta):
@@ -156,22 +154,55 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
         return data
 
     def _extract_lazy(self, lazy_data):
-        _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
+        _data, self.meta, _input_data, _target_data = lazy_data
         start_inp, end_inp = self.update_start_end(0)
         self._data = tuple(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1]))
         self.input_data = self._slice_prep(_input_data, start_inp, end_inp)
         self.target_data = self._slice_prep(_target_data, self.start, self.end)
 
 
-class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
+class DataHandlerMixedSamplingWithKzFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation,
+                                                        DataHandlerKzFilterSingleStation):
+    _requirements1 = DataHandlerKzFilterSingleStation.requirements()
+    _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
+    _requirements = list(set(_requirements1 + _requirements2))
+
+    def estimate_filter_width(self):
+        """
+        f = 0.5 / (len * sqrt(itr)) -> T = 1 / f
+        :return:
+        """
+        return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2)
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
+        super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data))
+
+
+class DataHandlerMixedSamplingWithKzFilter(DefaultDataHandler):
     """Data handler using mixed sampling for input and target. Inputs are temporal filtered."""
 
-    data_handler = DataHandlerMixedSamplingWithFilterSingleStation
-    data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation
+    data_handler = DataHandlerMixedSamplingWithKzFilterSingleStation
+    data_handler_transformation = DataHandlerMixedSamplingWithKzFilterSingleStation
     _requirements = data_handler.requirements()
 
 
-class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFilterSingleStation):
+class DataHandlerMixedSamplingWithFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation,
+                                                         DataHandlerFirFilterSingleStation):
+    _requirements1 = DataHandlerFirFilterSingleStation.requirements()
+    _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
+    _requirements = list(set(_requirements1 + _requirements2))
+
+    def estimate_filter_width(self):
+        """ """
+        return 5  # Todo: adjust this method
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.fir_coeff = lazy_data
+        super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data))
+
+
+class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation):
     """
     Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the
     separation frequency of a filtered time series the time step delta for input data is adjusted (see image below).
@@ -181,8 +212,8 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
 
     """
 
-    _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
-    _hash = DataHandlerMixedSamplingWithFilterSingleStation._hash + ["time_delta"]
+    _requirements = DataHandlerMixedSamplingWithKzFilterSingleStation.requirements()
+    _hash = DataHandlerMixedSamplingWithKzFilterSingleStation._hash + ["time_delta"]
 
     def __init__(self, *args, time_delta=np.sqrt, **kwargs):
         assert isinstance(time_delta, Callable)
diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py
index 0757e528..740642fe 100644
--- a/mlair/data_handler/data_handler_with_filter.py
+++ b/mlair/data_handler/data_handler_with_filter.py
@@ -96,10 +96,10 @@ class DataHandlerFilterSingleStation(DataHandlerSingleStation):
                                       self.filter_dim).copy()
 
     def _create_lazy_data(self):
-        return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days]
+        raise NotImplementedError
 
     def _extract_lazy(self, lazy_data):
-        _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
+        _data, self.meta, _input_data, _target_data = lazy_data
         f_prep = partial(self._slice_prep, start=self.start, end=self.end)
         self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data]))
 
@@ -181,6 +181,13 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation):
             index.append("unfiltered")
         return pd.Index(index, name=self.filter_dim)
 
+    def _create_lazy_data(self):
+        return [self._data, self.meta, self.input_data, self.target_data, self.fir_coeff]
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.fir_coeff = lazy_data
+        super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data))
+
 
 class DataHandlerFirFilter(DefaultDataHandler):
     """Data handler using FIR filtered data."""
@@ -233,6 +240,13 @@ class DataHandlerKzFilterSingleStation(DataHandlerFilterSingleStation):
         index = list(map(lambda x: str(x) + "d", index)) + ["res"]
         return pd.Index(index, name=self.filter_dim)
 
+    def _create_lazy_data(self):
+        return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days]
+
+    def _extract_lazy(self, lazy_data):
+        _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
+        super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data))
+
 
 class DataHandlerKzFilter(DefaultDataHandler):
     """Data handler using kz filtered data."""
diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py
index 19899a77..56751c44 100644
--- a/test/test_data_handler/test_data_handler_mixed_sampling.py
+++ b/test/test_data_handler/test_data_handler_mixed_sampling.py
@@ -2,8 +2,8 @@ __author__ = 'Lukas Leufen'
 __date__ = '2020-12-10'
 
 from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \
-    DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilter, \
-    DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerSeparationOfScales, \
+    DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithKzFilter, \
+    DataHandlerMixedSamplingWithKzFilterSingleStation, DataHandlerSeparationOfScales, \
     DataHandlerSeparationOfScalesSingleStation
 from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation
 from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
@@ -89,15 +89,15 @@ class TestDataHandlerMixedSamplingSingleStation:
 class TestDataHandlerMixedSamplingWithFilter:
 
     def test_data_handler(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
-        assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
+        assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__
 
     def test_data_handler_transformation(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
-        assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
+        assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__
 
     def test_requirements(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
         req1 = object.__new__(DataHandlerMixedSamplingSingleStation)
         req2 = object.__new__(DataHandlerKzFilterSingleStation)
         req = list(set(req1.requirements() + req2.requirements()))
@@ -119,7 +119,7 @@ class TestDataHandlerSeparationOfScales:
         assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__
 
     def test_requirements(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithFilter)
+        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
         req1 = object.__new__(DataHandlerMixedSamplingSingleStation)
         req2 = object.__new__(DataHandlerKzFilterSingleStation)
         req = list(set(req1.requirements() + req2.requirements()))
-- 
GitLab