From ec34a136d2c4cd0dfc79001f82bbce59ddd30b10 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Wed, 28 Apr 2021 12:38:29 +0200 Subject: [PATCH] new class DataHandlerMixedSamplingWithFilterSingleStation that bundles common methods of the kz and fir filter when used as mixed sampling --- .../data_handler_mixed_sampling.py | 65 ++++++++++++++----- .../data_handler/data_handler_with_filter.py | 18 ++++- .../test_data_handler_mixed_sampling.py | 16 ++--- 3 files changed, 72 insertions(+), 27 deletions(-) diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 4c84866b..71f9fe73 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -2,7 +2,8 @@ __author__ = 'Lukas Leufen' __date__ = '2020-11-05' from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation -from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation +from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation, \ + DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation from mlair.data_handler import DefaultDataHandler from mlair import helpers from mlair.helpers import remove_items @@ -94,8 +95,8 @@ class DataHandlerMixedSampling(DefaultDataHandler): class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation, - DataHandlerKzFilterSingleStation): - _requirements1 = DataHandlerKzFilterSingleStation.requirements() + DataHandlerFilterSingleStation): + _requirements1 = DataHandlerFilterSingleStation.requirements() _requirements2 = DataHandlerMixedSamplingSingleStation.requirements() _requirements = list(set(_requirements1 + _requirements2)) @@ -107,19 +108,16 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi def make_input_target(self): """ - A KZ filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values + A FIR filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values with daily resolution. """ self._data = tuple(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data self.set_inputs_and_targets() - self.apply_kz_filter() + self.apply_filter() def estimate_filter_width(self): - """ - f = 0.5 / (len * sqrt(itr)) -> T = 1 / f - :return: - """ - return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2) + """Return maximum filter width.""" + raise NotImplementedError @staticmethod def _add_time_delta(date, delta): @@ -156,22 +154,55 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi return data def _extract_lazy(self, lazy_data): - _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data + _data, self.meta, _input_data, _target_data = lazy_data start_inp, end_inp = self.update_start_end(0) self._data = tuple(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1])) self.input_data = self._slice_prep(_input_data, start_inp, end_inp) self.target_data = self._slice_prep(_target_data, self.start, self.end) -class DataHandlerMixedSamplingWithFilter(DefaultDataHandler): +class DataHandlerMixedSamplingWithKzFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, + DataHandlerKzFilterSingleStation): + _requirements1 = DataHandlerKzFilterSingleStation.requirements() + _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() + _requirements = list(set(_requirements1 + _requirements2)) + + def estimate_filter_width(self): + """ + f = 0.5 / (len * sqrt(itr)) -> T = 1 / f + :return: + """ + return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2) + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data + super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + + +class DataHandlerMixedSamplingWithKzFilter(DefaultDataHandler): """Data handler using mixed sampling for input and target. Inputs are temporal filtered.""" - data_handler = DataHandlerMixedSamplingWithFilterSingleStation - data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation + data_handler = DataHandlerMixedSamplingWithKzFilterSingleStation + data_handler_transformation = DataHandlerMixedSamplingWithKzFilterSingleStation _requirements = data_handler.requirements() -class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFilterSingleStation): +class DataHandlerMixedSamplingWithFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, + DataHandlerFirFilterSingleStation): + _requirements1 = DataHandlerFirFilterSingleStation.requirements() + _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() + _requirements = list(set(_requirements1 + _requirements2)) + + def estimate_filter_width(self): + """ """ + return 5 # Todo: adjust this method + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.fir_coeff = lazy_data + super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + + +class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation): """ Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the separation frequency of a filtered time series the time step delta for input data is adjusted (see image below). @@ -181,8 +212,8 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil """ - _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements() - _hash = DataHandlerMixedSamplingWithFilterSingleStation._hash + ["time_delta"] + _requirements = DataHandlerMixedSamplingWithKzFilterSingleStation.requirements() + _hash = DataHandlerMixedSamplingWithKzFilterSingleStation._hash + ["time_delta"] def __init__(self, *args, time_delta=np.sqrt, **kwargs): assert isinstance(time_delta, Callable) diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index 0757e528..740642fe 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -96,10 +96,10 @@ class DataHandlerFilterSingleStation(DataHandlerSingleStation): self.filter_dim).copy() def _create_lazy_data(self): - return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days] + raise NotImplementedError def _extract_lazy(self, lazy_data): - _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data + _data, self.meta, _input_data, _target_data = lazy_data f_prep = partial(self._slice_prep, start=self.start, end=self.end) self._data, self.input_data, self.target_data = list(map(f_prep, [_data, _input_data, _target_data])) @@ -181,6 +181,13 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): index.append("unfiltered") return pd.Index(index, name=self.filter_dim) + def _create_lazy_data(self): + return [self._data, self.meta, self.input_data, self.target_data, self.fir_coeff] + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.fir_coeff = lazy_data + super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + class DataHandlerFirFilter(DefaultDataHandler): """Data handler using FIR filtered data.""" @@ -233,6 +240,13 @@ class DataHandlerKzFilterSingleStation(DataHandlerFilterSingleStation): index = list(map(lambda x: str(x) + "d", index)) + ["res"] return pd.Index(index, name=self.filter_dim) + def _create_lazy_data(self): + return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days] + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data + super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + class DataHandlerKzFilter(DefaultDataHandler): """Data handler using kz filtered data.""" diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py index 19899a77..56751c44 100644 --- a/test/test_data_handler/test_data_handler_mixed_sampling.py +++ b/test/test_data_handler/test_data_handler_mixed_sampling.py @@ -2,8 +2,8 @@ __author__ = 'Lukas Leufen' __date__ = '2020-12-10' from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \ - DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilter, \ - DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerSeparationOfScales, \ + DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithKzFilter, \ + DataHandlerMixedSamplingWithKzFilterSingleStation, DataHandlerSeparationOfScales, \ DataHandlerSeparationOfScalesSingleStation from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation @@ -89,15 +89,15 @@ class TestDataHandlerMixedSamplingSingleStation: class TestDataHandlerMixedSamplingWithFilter: def test_data_handler(self): - obj = object.__new__(DataHandlerMixedSamplingWithFilter) - assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__ + obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) + assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__ def test_data_handler_transformation(self): - obj = object.__new__(DataHandlerMixedSamplingWithFilter) - assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithFilterSingleStation.__qualname__ + obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) + assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__ def test_requirements(self): - obj = object.__new__(DataHandlerMixedSamplingWithFilter) + obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) req1 = object.__new__(DataHandlerMixedSamplingSingleStation) req2 = object.__new__(DataHandlerKzFilterSingleStation) req = list(set(req1.requirements() + req2.requirements())) @@ -119,7 +119,7 @@ class TestDataHandlerSeparationOfScales: assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__ def test_requirements(self): - obj = object.__new__(DataHandlerMixedSamplingWithFilter) + obj = object.__new__(DataHandlerMixedSamplingWithKzFilter) req1 = object.__new__(DataHandlerMixedSamplingSingleStation) req2 = object.__new__(DataHandlerKzFilterSingleStation) req = list(set(req1.requirements() + req2.requirements())) -- GitLab