diff --git a/docs/_source/_plots/separation_of_scales.png b/docs/_source/_plots/separation_of_scales.png new file mode 100755 index 0000000000000000000000000000000000000000..d2bbc625a5d50051d8ec2babe976f88d7446e39e Binary files /dev/null and b/docs/_source/_plots/separation_of_scales.png differ diff --git a/mlair/data_handler/data_handler_kz_filter.py b/mlair/data_handler/data_handler_kz_filter.py index de1cb071369395edd9a8b6e869d65561dbfa0f11..6b960e79a14813c18f56a24642e78901bf687aad 100644 --- a/mlair/data_handler/data_handler_kz_filter.py +++ b/mlair/data_handler/data_handler_kz_filter.py @@ -25,11 +25,9 @@ class DataHandlerKzFilterSingleStation(DataHandlerSingleStation): def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs): self._check_sampling(**kwargs) - kz_filter_length = to_list(kz_filter_length) - kz_filter_iter = to_list(kz_filter_iter) # self.original_data = None # ToDo: implement here something to store unfiltered data - self.kz_filter_length = kz_filter_length - self.kz_filter_iter = kz_filter_iter + self.kz_filter_length = to_list(kz_filter_length) + self.kz_filter_iter = to_list(kz_filter_iter) self.cutoff_period = None self.cutoff_period_days = None super().__init__(*args, **kwargs) diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 0aae0ad6b212aff2fb914555f9aca83a48dddefb..1aec30b8dd7f4bd837aba6554c9e68b14375bd6c 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -4,15 +4,13 @@ __date__ = '2020-11-05' from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation from mlair.data_handler.data_handler_kz_filter import DataHandlerKzFilterSingleStation from mlair.data_handler import DefaultDataHandler -from mlair.configuration import path_config from mlair import helpers from mlair.helpers import remove_items from mlair.configuration.defaults import DEFAULT_SAMPLING -import 
logging
-import os
 import inspect
 from typing import Callable
+import datetime as dt
 
 import numpy as np
 import pandas as pd
@@ -39,7 +37,7 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
 
     def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
         data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind],
-                                         self.station_type, self.network, self.store_data_locally)
+                                         self.station_type, self.network, self.store_data_locally, self.start, self.end)
         data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
                                 limit=self.interpolation_limit)
         return data
@@ -90,6 +88,33 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
         self.call_transform()
         self.make_samples()
 
+    def estimate_filter_width(self):
+        """
+        f = 0.5 / (len * sqrt(itr)) -> T = 1 / f
+        :return:
+        """
+        return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2)
+
+    @staticmethod
+    def _add_time_delta(date, delta):
+        new_date = dt.datetime.strptime(date, "%Y-%m-%d") + dt.timedelta(hours=delta)
+        return new_date.strftime("%Y-%m-%d")
+
+    def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
+
+        if ind == 0:  # for inputs
+            estimated_filter_width = self.estimate_filter_width()
+            start = self._add_time_delta(self.start, -estimated_filter_width)
+            end = self._add_time_delta(self.end, estimated_filter_width)
+        else:  # target
+            start, end = self.start, self.end
+
+        data, self.meta = self.load_data(self.path[ind], self.station, self.statistics_per_var, self.sampling[ind],
+                                         self.station_type, self.network, self.store_data_locally, start, end)
+        data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method,
+                                limit=self.interpolation_limit)
+        return data
+
 
 class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
     """Data handler using mixed sampling for input and target. 
Inputs are temporal filtered."""
@@ -100,6 +125,15 @@ class DataHandlerMixedSamplingWithFilter(DefaultDataHandler):
 
 
 class DataHandlerMixedSamplingSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFilterSingleStation):
+    """
+    Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the
+    separation frequency of a filtered time series the time step delta for input data is adjusted (see image below).
+
+    .. image:: ../../../../../docs/_source/_plots/separation_of_scales.png
+        :width: 400
+
+    """
+
     _requirements = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
 
     def __init__(self, *args, time_delta=np.sqrt, **kwargs):
@@ -147,11 +181,21 @@ class DataHandlerMixedSamplingSeparationOfScalesSingleStation(DataHandlerMixedSa
         res = xr.concat(res, dim="filter").chunk()
         return res
 
+    def estimate_filter_width(self):
+        """
+        Attention: this method returns the maximum value of
+        * either estimated filter width f = 0.5 / (len * sqrt(itr)) -> T = 1 / f or
+        * time delta method applied on the estimated filter width multiplied by window_history_size
+        to provide a sufficiently wide filter width.
+        """
+        est = self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2
+        return int(max([self.time_delta(est) * self.window_history_size, est]))
+
 
 class DataHandlerMixedSamplingSeparationOfScales(DefaultDataHandler):
     """Data handler using mixed sampling for input and target. 
Inputs are temporal filtered and different time step sizes are applied in relation to frequencies.""" - data_handler = DataHandlerMixedSamplingWithFilterSingleStation - data_handler_transformation = DataHandlerMixedSamplingWithFilterSingleStation + data_handler = DataHandlerMixedSamplingSeparationOfScalesSingleStation + data_handler_transformation = DataHandlerMixedSamplingSeparationOfScalesSingleStation _requirements = data_handler.requirements() diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index cd922e7535124b2d83be2ac9aa3e53f5df949ba6..4c274f913ac6668a84793c1d0628728334601d51 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -142,7 +142,7 @@ class DataHandlerSingleStation(AbstractDataHandler): Setup samples. This method prepares and creates samples X, and labels Y. """ data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, - self.station_type, self.network, self.store_data_locally) + self.station_type, self.network, self.store_data_locally, self.start, self.end) self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) self.set_inputs_and_targets() @@ -163,7 +163,7 @@ class DataHandlerSingleStation(AbstractDataHandler): self.remove_nan(self.time_dim) def load_data(self, path, station, statistics_per_var, sampling, station_type=None, network=None, - store_data_locally=False): + store_data_locally=False, start=None, end=None): """ Load data and meta data either from local disk (preferred) or download new data by using a custom download method. @@ -199,7 +199,7 @@ class DataHandlerSingleStation(AbstractDataHandler): store_data_locally=store_data_locally) logging.debug("loading finished") # create slices and check for negative concentration. 
- data = self._slice_prep(data) + data = self._slice_prep(data, start=start, end=end) data = self.check_for_negative_concentrations(data) return data, meta @@ -442,7 +442,7 @@ class DataHandlerSingleStation(AbstractDataHandler): self.label = self.label.sel({dim: intersect}) self.observation = self.observation.sel({dim: intersect}) - def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray: + def _slice_prep(self, data: xr.DataArray, start=None, end=None) -> xr.DataArray: """ Set start and end date for slicing and execute self._slice(). @@ -451,9 +451,9 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: sliced data """ - start = self.start if self.start is not None else data.coords[coord][0].values - end = self.end if self.end is not None else data.coords[coord][-1].values - return self._slice(data, start, end, coord) + start = start if start is not None else data.coords[self.time_dim][0].values + end = end if end is not None else data.coords[self.time_dim][-1].values + return self._slice(data, start, end, self.time_dim) @staticmethod def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray: diff --git a/test/test_data_handler/test_data_handler.py b/test/test_data_handler/test_data_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..418c7946efe160c9bbfeccff9908a6cf17dec17f --- /dev/null +++ b/test/test_data_handler/test_data_handler.py @@ -0,0 +1,67 @@ +import pytest +import inspect + +from mlair.data_handler.abstract_data_handler import AbstractDataHandler + + +class TestDefaultDataHandler: + + def test_required_attributes(self): + dh = AbstractDataHandler + assert hasattr(dh, "_requirements") + assert hasattr(dh, "__init__") + assert hasattr(dh, "build") + assert hasattr(dh, "requirements") + assert hasattr(dh, "own_args") + assert hasattr(dh, "transformation") + assert hasattr(dh, "get_X") + assert hasattr(dh, "get_Y") + assert hasattr(dh, 
"get_data") + assert hasattr(dh, "get_coordinates") + + def test_init(self): + assert isinstance(AbstractDataHandler(), AbstractDataHandler) + + def test_build(self): + assert isinstance(AbstractDataHandler.build(), AbstractDataHandler) + + def test_requirements(self): + dh = AbstractDataHandler() + assert isinstance(dh._requirements, list) + assert len(dh._requirements) == 0 + assert isinstance(dh.requirements(), list) + assert len(dh.requirements()) == 0 + + def test_own_args(self): + dh = AbstractDataHandler() + assert isinstance(dh.own_args(), list) + assert len(dh.own_args()) == 0 + assert "self" not in dh.own_args() + + def test_transformation(self): + assert AbstractDataHandler.transformation() is None + + def test_get_X(self): + dh = AbstractDataHandler() + with pytest.raises(NotImplementedError): + dh.get_X() + assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_X).args) + assert (False, False) == inspect.getfullargspec(dh.get_X).defaults + + def test_get_Y(self): + dh = AbstractDataHandler() + with pytest.raises(NotImplementedError): + dh.get_Y() + assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_Y).args) + assert (False, False) == inspect.getfullargspec(dh.get_Y).defaults + + def test_get_data(self): + dh = AbstractDataHandler() + with pytest.raises(NotImplementedError): + dh.get_data() + assert sorted(["self", "upsampling", "as_numpy"]) == sorted(inspect.getfullargspec(dh.get_data).args) + assert (False, False) == inspect.getfullargspec(dh.get_data).defaults + + def test_get_coordinates(self): + dh = AbstractDataHandler() + assert dh.get_coordinates() is None