diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py index 36d6e9ae5394705af4b9fbcfd1d8ff77572642b5..9ea163fcad2890580e9c44e4bda0627d6419dc9f 100644 --- a/mlair/data_handler/abstract_data_handler.py +++ b/mlair/data_handler/abstract_data_handler.py @@ -5,13 +5,14 @@ __date__ = '2020-09-21' import inspect from typing import Union, Dict -from mlair.helpers import remove_items +from mlair.helpers import remove_items, to_list -class AbstractDataHandler: +class AbstractDataHandler(object): _requirements = [] _store_attributes = [] + _skip_args = ["self"] def __init__(self, *args, **kwargs): pass @@ -22,16 +23,28 @@ class AbstractDataHandler: return cls(*args, **kwargs) @classmethod - def requirements(cls): + def requirements(cls, skip_args=None): """Return requirements and own arguments without duplicates.""" - return list(set(cls._requirements + cls.own_args())) + skip_args = cls._skip_args if skip_args is None else cls._skip_args + to_list(skip_args) + return remove_items(list(set(cls._requirements + cls.own_args())), skip_args) @classmethod def own_args(cls, *args): """Return all arguments (including kwonlyargs).""" arg_spec = inspect.getfullargspec(cls) - list_of_args = arg_spec.args + arg_spec.kwonlyargs - return remove_items(list_of_args, ["self"] + list(args)) + list_of_args = arg_spec.args + arg_spec.kwonlyargs + cls.super_args() + return list(set(remove_items(list_of_args, list(args)))) + + @classmethod + def super_args(cls): + args = [] + for super_cls in cls.__mro__: + if super_cls == cls: + continue + if hasattr(super_cls, "own_args"): + # args.extend(super_cls.own_args()) + args.extend(getattr(super_cls, "own_args")()) + return list(set(args)) @classmethod def store_attributes(cls) -> list: diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 5aefb0368ec1cf544443bb5e0412dd16a97f2a7f..eb3f78dc465247095d0114f3f41d4b8b70ba5480 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -2,30 +2,25 @@ __author__ = 'Lukas Leufen' __date__ = '2020-11-05' from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation -from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation, \ - DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation -from mlair.data_handler.data_handler_with_filter import DataHandlerClimateFirFilter, DataHandlerFirFilter, \ - DataHandlerKzFilter +from mlair.data_handler.data_handler_with_filter import DataHandlerFirFilterSingleStation, \ + DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation +from mlair.data_handler.data_handler_with_filter import DataHandlerClimateFirFilter, DataHandlerFirFilter from mlair.data_handler import DefaultDataHandler from mlair import helpers -from mlair.helpers import remove_items +from mlair.helpers import to_list from mlair.configuration.defaults import DEFAULT_SAMPLING, DEFAULT_INTERPOLATION_LIMIT, DEFAULT_INTERPOLATION_METHOD from mlair.helpers.filter import filter_width_kzf import copy -import inspect -from typing import Callable import datetime as dt from typing import Any from functools import partial -import numpy as np import pandas as pd import xarray as xr class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): - _requirements = remove_items(inspect.getfullargspec(DataHandlerSingleStation).args, ["self", "station"]) 
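[Aside, not part of the patch: the removed `_requirements` lines above become redundant because requirements are now collected dynamically. A minimal standalone sketch of the mechanism the new `requirements`/`own_args`/`super_args` trio implements — toy class names, not MLAir code; only the traversal logic mirrors the patch.]

import inspect


class Parent:
    _requirements = []
    _skip_args = ["self"]

    def __init__(self, station, data_path):
        self.station, self.data_path = station, data_path

    @classmethod
    def own_args(cls, *args):
        # getfullargspec on a class inspects its __init__ (incl. "self")
        spec = inspect.getfullargspec(cls)
        collected = spec.args + spec.kwonlyargs + cls.super_args()
        return list(set(a for a in collected if a not in args))

    @classmethod
    def super_args(cls):
        args = []
        for super_cls in cls.__mro__:
            if super_cls is cls or not hasattr(super_cls, "own_args"):
                continue  # skip the class itself and plain `object`
            args.extend(super_cls.own_args())
        return list(set(args))

    @classmethod
    def requirements(cls, skip_args=None):
        skip = cls._skip_args if skip_args is None else cls._skip_args + list(skip_args)
        return [a for a in set(cls._requirements + cls.own_args()) if a not in skip]


class Child(Parent):
    def __init__(self, *args, sampling="daily", **kwargs):
        self.sampling = sampling
        super().__init__(*args, **kwargs)


print(sorted(Child.requirements()))  # ['data_path', 'sampling', 'station']

[Because `super_args` walks the MRO, subclasses such as the mixed-sampling handler below no longer have to merge parent requirements by hand.]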
def __init__(self, *args, **kwargs):
         """
@@ -101,9 +96,6 @@ class DataHandlerMixedSampling(DefaultDataHandler):
 
 class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSingleStation,
                                                       DataHandlerFilterSingleStation):
-    _requirements1 = DataHandlerFilterSingleStation.requirements()
-    _requirements2 = DataHandlerMixedSamplingSingleStation.requirements()
-    _requirements = list(set(_requirements1 + _requirements2))
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -111,6 +103,16 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
     def _check_sampling(self, **kwargs):
         assert kwargs.get("sampling") == ("hourly", "daily")
 
+    def apply_filter(self):
+        raise NotImplementedError
+
+    def create_filter_index(self) -> pd.Index:
+        """Create name for filter dimension."""
+        raise NotImplementedError
+
+    def _create_lazy_data(self):
+        raise NotImplementedError
+
     def make_input_target(self):
         """
         A FIR filter is applied on the input data that has hourly resolution. Labels Y are provided as aggregated values
@@ -159,46 +161,31 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
         self.target_data = self._slice_prep(_target_data, self.start, self.end)
 
 
-class DataHandlerMixedSamplingWithKzFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation,
-                                                        DataHandlerKzFilterSingleStation):
-    _requirements1 = DataHandlerKzFilterSingleStation.requirements()
-    _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements()
-    _requirements = list(set(_requirements1 + _requirements2))
-
-    def estimate_filter_width(self):
-        """
-        f = 0.5 / (len * sqrt(itr)) -> T = 1 / f
-        :return:
-        """
-        return int(self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2)
-
-    def _extract_lazy(self, lazy_data):
-        _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days, \
-            self.filter_dim_order = lazy_data
-        super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data))
-
-
-class DataHandlerMixedSamplingWithKzFilter(DataHandlerKzFilter):
-    """Data handler using mixed sampling for input and target. 
Inputs are temporal filtered.""" - - data_handler = DataHandlerMixedSamplingWithKzFilterSingleStation - data_handler_transformation = DataHandlerMixedSamplingWithKzFilterSingleStation - _requirements = data_handler.requirements() - - class DataHandlerMixedSamplingWithFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, DataHandlerFirFilterSingleStation): - _requirements1 = DataHandlerFirFilterSingleStation.requirements() - _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() - _requirements = list(set(_requirements1 + _requirements2)) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def estimate_filter_width(self): """Filter width is determined by the filter with the highest order.""" - return max(self.filter_order) + if isinstance(self.filter_order[0], tuple): + return max([filter_width_kzf(*e) for e in self.filter_order]) + else: + return max(self.filter_order) + + def apply_filter(self): + DataHandlerFirFilterSingleStation.apply_filter(self) + + def create_filter_index(self, add_unfiltered_index=True) -> pd.Index: + return DataHandlerFirFilterSingleStation.create_filter_index(self, add_unfiltered_index=add_unfiltered_index) def _extract_lazy(self, lazy_data): _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data - super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + DataHandlerMixedSamplingWithFilterSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data)) + + def _create_lazy_data(self): + return DataHandlerFirFilterSingleStation._create_lazy_data(self) @staticmethod def _get_fs(**kwargs): @@ -220,18 +207,8 @@ class DataHandlerMixedSamplingWithFirFilter(DataHandlerFirFilter): _requirements = data_handler.requirements() -class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerMixedSamplingWithFilterSingleStation, - DataHandlerClimateFirFilterSingleStation): - _requirements1 = DataHandlerClimateFirFilterSingleStation.requirements() - _requirements2 = DataHandlerMixedSamplingWithFilterSingleStation.requirements() - _requirements = list(set(_requirements1 + _requirements2)) - - def estimate_filter_width(self): - """Filter width is determined by the filter with the highest order.""" - if isinstance(self.filter_order[0], tuple): - return max([filter_width_kzf(*e) for e in self.filter_order]) - else: - return max(self.filter_order) +class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerClimateFirFilterSingleStation, + DataHandlerMixedSamplingWithFirFilterSingleStation): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -241,17 +218,6 @@ class DataHandlerMixedSamplingWithClimateFirFilterSingleStation(DataHandlerMixed self.filter_dim_order = lazy_data DataHandlerMixedSamplingWithFilterSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data)) - @staticmethod - def _get_fs(**kwargs): - """Return frequency in 1/day (not Hz)""" - sampling = kwargs.get("sampling")[0] - if sampling == "daily": - return 1 - elif sampling == "hourly": - return 24 - else: - raise ValueError(f"Unknown sampling rate {sampling}. Only daily and hourly resolution is supported.") - class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter): """Data handler using mixed sampling for input and target. 
Inputs are temporally filtered."""
@@ -268,19 +234,11 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter):
         self.filter_add_unfiltered = filter_add_unfiltered
         super().__init__(*args, **kwargs)
 
-    @classmethod
-    def own_args(cls, *args):
-        """Return all arguments (including kwonlyargs)."""
-        super_own_args = DataHandlerClimateFirFilter.own_args(*args)
-        arg_spec = inspect.getfullargspec(cls)
-        list_of_args = arg_spec.args + arg_spec.kwonlyargs + super_own_args
-        return remove_items(list_of_args, ["self"] + list(args))
-
     def _create_collection(self):
+        collection = super()._create_collection()
         if self.filter_add_unfiltered is True and self.dh_unfiltered is not None:
-            return [self.id_class, self.dh_unfiltered]
-        else:
-            return super()._create_collection()
+            collection.append(self.dh_unfiltered)
+        return collection
 
     @classmethod
     def build(cls, station: str, **kwargs):
@@ -306,19 +264,23 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter):
         return kwargs_dict
 
     @classmethod
-    def transformation(cls, set_stations, tmp_path=None, **kwargs):
+    def transformation(cls, set_stations, tmp_path=None, dh_transformation=None, **kwargs):
 
-        sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
-        if "transformation" not in sp_keys.keys():
+        # sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
+        if "transformation" not in kwargs.keys():
             return
+        if dh_transformation is None:
+            dh_transformation = (cls.data_handler_transformation, cls.data_handler_unfiltered)
+        elif not isinstance(dh_transformation, tuple):
+            dh_transformation = (dh_transformation, dh_transformation)
         transformation_filtered = super().transformation(set_stations, tmp_path=tmp_path,
-                                                         dh_transformation=cls.data_handler_transformation, **kwargs)
+                                                         dh_transformation=dh_transformation[0], **kwargs)
         if kwargs.get("filter_add_unfiltered", False) is False:
             return transformation_filtered
         else:
             transformation_unfiltered = super().transformation(set_stations, tmp_path=tmp_path,
-                                                               dh_transformation=cls.data_handler_unfiltered, **kwargs)
+                                                               dh_transformation=dh_transformation[1], **kwargs)
             return {"filtered": transformation_filtered, "unfiltered": transformation_unfiltered}
 
     def get_X_original(self):
@@ -337,80 +299,222 @@ class DataHandlerMixedSamplingWithClimateFirFilter(DataHandlerClimateFirFilter):
             return super().get_X_original()
 
 
-class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation):
-    """
-    Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the
-    separation frequency of a filtered time series the time step delta for input data is adjusted (see image below). 
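[Aside, not part of the patch: the reworked transformation classmethod above first normalises `dh_transformation` and then computes one transformation per branch. A condensed, hypothetical sketch of that dispatch — `FilteredHandler`, `UnfilteredHandler` and `get_transformation` are illustrative stand-ins, not MLAir API.]

class FilteredHandler:
    pass


class UnfilteredHandler:
    pass


def get_transformation(set_stations, handler_cls, **kwargs):
    # stand-in for the expensive per-handler transformation estimation
    return {"handler": handler_cls.__name__, "stations": list(set_stations)}


def transformation(set_stations, dh_transformation=None, **kwargs):
    if "transformation" not in kwargs:
        return None
    # normalise: a single class serves both branches, a 2-tuple addresses
    # (filtered, unfiltered) separately
    if dh_transformation is None:
        dh_transformation = (FilteredHandler, UnfilteredHandler)
    elif not isinstance(dh_transformation, tuple):
        dh_transformation = (dh_transformation, dh_transformation)
    filtered = get_transformation(set_stations, dh_transformation[0], **kwargs)
    if kwargs.get("filter_add_unfiltered", False) is False:
        return filtered
    unfiltered = get_transformation(set_stations, dh_transformation[1], **kwargs)
    return {"filtered": filtered, "unfiltered": unfiltered}


print(transformation(["DEBW107"], transformation={"method": "standardise"},
                     filter_add_unfiltered=True))
# -> {'filtered': {...}, 'unfiltered': {...}}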
+class DataHandlerMixedSamplingWithClimateAndFirFilter(DataHandlerMixedSamplingWithClimateFirFilter): + # data_handler = DataHandlerMixedSamplingWithClimateFirFilterSingleStation + # data_handler_transformation = DataHandlerMixedSamplingWithClimateFirFilterSingleStation + # data_handler_unfiltered = DataHandlerMixedSamplingSingleStation + # _requirements = list(set(data_handler.requirements() + data_handler_unfiltered.requirements())) + # DEFAULT_FILTER_ADD_UNFILTERED = False + data_handler_climate_fir = DataHandlerMixedSamplingWithClimateFirFilterSingleStation + data_handler_fir = DataHandlerMixedSamplingWithFirFilterSingleStation + data_handler = None + data_handler_unfiltered = DataHandlerMixedSamplingSingleStation + _requirements = list(set(data_handler_climate_fir.requirements() + data_handler_fir.requirements() + + data_handler_unfiltered.requirements())) - .. image:: ../../../../../_source/_plots/separation_of_scales.png - :width: 400 + def __init__(self, data_handler_class_chem, data_handler_class_meteo, data_handler_class_chem_unfiltered, + data_handler_class_meteo_unfiltered, chem_vars, meteo_vars, *args, **kwargs): - """ + if len(chem_vars) > 0: + id_class, id_class_unfiltered = data_handler_class_chem, data_handler_class_chem_unfiltered + self.id_class_other = data_handler_class_meteo + self.id_class_other_unfiltered = data_handler_class_meteo_unfiltered + else: + id_class, id_class_unfiltered = data_handler_class_meteo, data_handler_class_meteo_unfiltered + self.id_class_other = data_handler_class_chem + self.id_class_other_unfiltered = data_handler_class_chem_unfiltered + super().__init__(id_class, *args, data_handler_class_unfiltered=id_class_unfiltered, **kwargs) - _requirements = DataHandlerMixedSamplingWithKzFilterSingleStation.requirements() - _hash = DataHandlerMixedSamplingWithKzFilterSingleStation._hash + ["time_delta"] + @classmethod + def _split_chem_and_meteo_variables(cls, **kwargs): + if "variables" in kwargs: + variables = kwargs.get("variables") + elif "statistics_per_var" in kwargs: + variables = kwargs.get("statistics_per_var") + else: + variables = None + if variables is None: + variables = cls.data_handler_climate_fir.DEFAULT_VAR_ALL_DICT.keys() + chem_vars = cls.data_handler_climate_fir.chem_vars + chem = set(variables).intersection(chem_vars) + meteo = set(variables).difference(chem_vars) + return to_list(chem), to_list(meteo) - def __init__(self, *args, time_delta=np.sqrt, **kwargs): - assert isinstance(time_delta, Callable) - self.time_delta = time_delta - super().__init__(*args, **kwargs) + @classmethod + def build(cls, station: str, **kwargs): + chem_vars, meteo_vars = cls._split_chem_and_meteo_variables(**kwargs) + filter_add_unfiltered = kwargs.get("filter_add_unfiltered", False) + sp_chem, sp_chem_unfiltered = None, None + sp_meteo, sp_meteo_unfiltered = None, None + + if len(chem_vars) > 0: + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_climate_fir.requirements() if k in kwargs} + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="filtered_chem") + sp_keys.update({"variables": chem_vars}) + cls.adjust_window_opts("chem", "window_history_size", sp_keys) + cls.adjust_window_opts("chem", "window_history_offset", sp_keys) + sp_chem = cls.data_handler_climate_fir(station, **sp_keys) + if filter_add_unfiltered is True: + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs} + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="unfiltered_chem") + sp_keys.update({"variables": 
chem_vars}) + cls.adjust_window_opts("chem", "window_history_size", sp_keys) + cls.adjust_window_opts("chem", "window_history_offset", sp_keys) + sp_chem_unfiltered = cls.data_handler_unfiltered(station, **sp_keys) + if len(meteo_vars) > 0: + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_fir.requirements() if k in kwargs} + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="filtered_meteo") + sp_keys.update({"variables": meteo_vars}) + cls.adjust_window_opts("meteo", "window_history_size", sp_keys) + cls.adjust_window_opts("meteo", "window_history_offset", sp_keys) + sp_meteo = cls.data_handler_fir(station, **sp_keys) + if filter_add_unfiltered is True: + sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls.data_handler_unfiltered.requirements() if k in kwargs} + sp_keys = cls.build_update_kwargs(sp_keys, dh_type="unfiltered_meteo") + sp_keys.update({"variables": meteo_vars}) + cls.adjust_window_opts("meteo", "window_history_size", sp_keys) + cls.adjust_window_opts("meteo", "window_history_offset", sp_keys) + sp_meteo_unfiltered = cls.data_handler_unfiltered(station, **sp_keys) - def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: - """ - Create a xr.DataArray containing history data. + dp_args = {k: copy.deepcopy(kwargs[k]) for k in cls.own_args("id_class") if k in kwargs} + return cls(sp_chem, sp_meteo, sp_chem_unfiltered, sp_meteo_unfiltered, chem_vars, meteo_vars, **dp_args) - Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted - data. This is used to represent history in the data. Results are stored in history attribute. + @staticmethod + def adjust_window_opts(key: str, parameter_name: str, kwargs: dict): + if parameter_name in kwargs: + window_opt = kwargs.pop(parameter_name) + if isinstance(window_opt, dict): + window_opt = window_opt[key] + kwargs[parameter_name] = window_opt - :param dim_name_of_inputs: Name of dimension which contains the input variables - :param window: number of time steps to look back in history - Note: window will be treated as negative value. This should be in agreement with looking back on - a time line. 
Nonetheless positive values are allowed but they are converted to its negative - expression - :param dim_name_of_shift: Dimension along shift will be applied - """ - window = -abs(window) - data = self.input_data - self.history = self.stride(data, dim_name_of_shift, window, offset=self.window_history_offset) - - def stride(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray: - - # this is just a code snippet to check the results of the kz filter - # import matplotlib - # matplotlib.use("TkAgg") - # import matplotlib.pyplot as plt - # xr.concat(res, dim="filter").sel({"variables":"temp", "Stations":"DEBW107", "datetime":"2010-01-01T00:00:00"}).plot.line(hue="filter") - - time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int) - start, end = window, 1 - res = [] - _range = list(map(lambda x: x + offset, range(start, end))) - window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim) - for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]): - res_filter = [] - data_filter = data.sel({"filter": filter_name}) - for w in _range: - res_filter.append(data_filter.shift({dim: -(w - offset) * delta - offset})) - res_filter = xr.concat(res_filter, dim=window_array).chunk() - res.append(res_filter) - res = xr.concat(res, dim="filter").compute() - return res + def _create_collection(self): + collection = super()._create_collection() + if self.id_class_other is not None: + collection.append(self.id_class_other) + if self.filter_add_unfiltered is True and self.id_class_other_unfiltered is not None: + collection.append(self.id_class_other_unfiltered) + return collection - def estimate_filter_width(self): - """ - Attention: this method returns the maximum value of - * either estimated filter width f = 0.5 / (len * sqrt(itr)) -> T = 1 / f or - * time delta method applied on the estimated filter width mupliplied by window_history_size - to provide a sufficiently wide filter width. - """ - est = self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2 - return int(max([self.time_delta(est) * self.window_history_size, est])) + @classmethod + def transformation(cls, set_stations, tmp_path=None, **kwargs): + if "transformation" not in kwargs.keys(): + return -class DataHandlerSeparationOfScales(DefaultDataHandler): - """Data handler using mixed sampling for input and target. 
Inputs are temporal filtered and different time step - sizes are applied in relation to frequencies.""" + chem_vars, meteo_vars = cls._split_chem_and_meteo_variables(**kwargs) + + # chem transformation + kwargs_chem = copy.deepcopy(kwargs) + kwargs_chem["variables"] = chem_vars + cls.adjust_window_opts("chem", "window_history_size", kwargs_chem) + cls.adjust_window_opts("chem", "window_history_offset", kwargs_chem) + dh_transformation = (cls.data_handler_climate_fir, cls.data_handler_unfiltered) + transformation_chem = super().transformation(set_stations, tmp_path=tmp_path, + dh_transformation=dh_transformation, **kwargs_chem) + + # meteo transformation + kwargs_meteo = copy.deepcopy(kwargs) + kwargs_meteo["variables"] = meteo_vars + cls.adjust_window_opts("meteo", "window_history_size", kwargs_meteo) + cls.adjust_window_opts("meteo", "window_history_offset", kwargs_meteo) + dh_transformation = (cls.data_handler_fir, cls.data_handler_unfiltered) + transformation_meteo = super().transformation(set_stations, tmp_path=tmp_path, + dh_transformation=dh_transformation, **kwargs_meteo) + + # combine all transformations + transformation_res = {} + if len(transformation_chem) > 0: + transformation_res["filtered_chem"] = transformation_chem.pop("filtered") + transformation_res["unfiltered_chem"] = transformation_chem.pop("unfiltered") + if len(transformation_meteo) > 0: + transformation_res["filtered_meteo"] = transformation_meteo.pop("filtered") + transformation_res["unfiltered_meteo"] = transformation_meteo.pop("unfiltered") + return transformation_res if len(transformation_res) > 0 else None - data_handler = DataHandlerSeparationOfScalesSingleStation - data_handler_transformation = DataHandlerSeparationOfScalesSingleStation - _requirements = data_handler.requirements() + def get_X_original(self): + if self.use_filter_branches is True: + X = [] + for data in self._collection: + if hasattr(data, "filter_dim"): + X_total = data.get_X() + filter_dim = data.filter_dim + for filter_name in data.filter_dim_order: + X.append(X_total.sel({filter_dim: filter_name}, drop=True)) + else: + X.append(data.get_X()) + return X + else: + return super().get_X_original() + + +# +# class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithKzFilterSingleStation): +# """ +# Data handler using mixed sampling for input and target. Inputs are temporal filtered and depending on the +# separation frequency of a filtered time series the time step delta for input data is adjusted (see image below). +# +# .. image:: ../../../../../_source/_plots/separation_of_scales.png +# :width: 400 +# +# """ +# +# _requirements = DataHandlerMixedSamplingWithKzFilterSingleStation.requirements() +# _hash = DataHandlerMixedSamplingWithKzFilterSingleStation._hash + ["time_delta"] +# +# def __init__(self, *args, time_delta=np.sqrt, **kwargs): +# assert isinstance(time_delta, Callable) +# self.time_delta = time_delta +# super().__init__(*args, **kwargs) +# +# def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: +# """ +# Create a xr.DataArray containing history data. +# +# Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted +# data. This is used to represent history in the data. Results are stored in history attribute. +# +# :param dim_name_of_inputs: Name of dimension which contains the input variables +# :param window: number of time steps to look back in history +# Note: window will be treated as negative value. 
This should be in agreement with looking back on +# a time line. Nonetheless positive values are allowed but they are converted to its negative +# expression +# :param dim_name_of_shift: Dimension along shift will be applied +# """ +# window = -abs(window) +# data = self.input_data +# self.history = self.stride(data, dim_name_of_shift, window, offset=self.window_history_offset) +# +# def stride(self, data: xr.DataArray, dim: str, window: int, offset: int = 0) -> xr.DataArray: +# time_deltas = np.round(self.time_delta(self.cutoff_period)).astype(int) +# start, end = window, 1 +# res = [] +# _range = list(map(lambda x: x + offset, range(start, end))) +# window_array = self.create_index_array(self.window_dim, _range, squeeze_dim=self.target_dim) +# for delta, filter_name in zip(np.append(time_deltas, 1), data.coords["filter"]): +# res_filter = [] +# data_filter = data.sel({"filter": filter_name}) +# for w in _range: +# res_filter.append(data_filter.shift({dim: -(w - offset) * delta - offset})) +# res_filter = xr.concat(res_filter, dim=window_array).chunk() +# res.append(res_filter) +# res = xr.concat(res, dim="filter").compute() +# return res +# +# def estimate_filter_width(self): +# """ +# Attention: this method returns the maximum value of +# * either estimated filter width f = 0.5 / (len * sqrt(itr)) -> T = 1 / f or +# * time delta method applied on the estimated filter width mupliplied by window_history_size +# to provide a sufficiently wide filter width. +# """ +# est = self.kz_filter_length[0] * np.sqrt(self.kz_filter_iter[0]) * 2 +# return int(max([self.time_delta(est) * self.window_history_size, est])) +# +# +# class DataHandlerSeparationOfScales(DefaultDataHandler): +# """Data handler using mixed sampling for input and target. Inputs are temporal filtered and different time step +# sizes are applied in relation to frequencies.""" +# +# data_handler = DataHandlerSeparationOfScalesSingleStation +# data_handler_transformation = DataHandlerSeparationOfScalesSingleStation +# _requirements = data_handler.requirements() diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py index 88a57d108e4533968eeb9a65aabf575fae085704..c21d5b4126b8bc7564ef8855d7c1229c7f411df3 100644 --- a/mlair/data_handler/data_handler_single_station.py +++ b/mlair/data_handler/data_handler_single_station.py @@ -48,12 +48,14 @@ class DataHandlerSingleStation(AbstractDataHandler): DEFAULT_SAMPLING = "daily" DEFAULT_INTERPOLATION_LIMIT = 0 DEFAULT_INTERPOLATION_METHOD = "linear" + chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane", + "so2", "toluene"] _hash = ["station", "statistics_per_var", "data_origin", "station_type", "network", "sampling", "target_dim", "target_var", "time_dim", "iter_dim", "window_dim", "window_history_size", "window_history_offset", "window_lead_time", "interpolation_limit", "interpolation_method"] - def __init__(self, station, data_path, statistics_per_var, station_type=DEFAULT_STATION_TYPE, + def __init__(self, station, data_path, statistics_per_var=None, station_type=DEFAULT_STATION_TYPE, network=DEFAULT_NETWORK, sampling: Union[str, Tuple[str]] = DEFAULT_SAMPLING, target_dim=DEFAULT_TARGET_DIM, target_var=DEFAULT_TARGET_VAR, time_dim=DEFAULT_TIME_DIM, iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM, @@ -72,7 +74,7 @@ class DataHandlerSingleStation(AbstractDataHandler): if self.lazy is True: self.lazy_path = os.path.join(data_path, "lazy_data", 
self.__class__.__name__) check_path_and_create(self.lazy_path) - self.statistics_per_var = statistics_per_var + self.statistics_per_var = statistics_per_var or self.DEFAULT_VAR_ALL_DICT self.data_origin = data_origin self.do_transformation = transformation is not None self.input_data, self.target_data = None, None @@ -415,9 +417,7 @@ class DataHandlerSingleStation(AbstractDataHandler): :return: corrected data """ - chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", - "propane", "so2", "toluene"] - used_chem_vars = list(set(chem_vars) & set(data.coords[self.target_dim].values)) + used_chem_vars = list(set(self.chem_vars) & set(data.coords[self.target_dim].values)) if len(used_chem_vars) > 0: data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum) return data @@ -750,24 +750,3 @@ class DataHandlerSingleStation(AbstractDataHandler): def _get_hash(self): hash = "".join([str(self.__getattribute__(e)) for e in self._hash_list()]).encode() return hashlib.md5(hash).hexdigest() - - -if __name__ == "__main__": - statistics_per_var = {'o3': 'dma8eu', 'temp-rea-miub': 'maximum'} - sp = DataHandlerSingleStation(data_path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122', - statistics_per_var=statistics_per_var, station_type='background', - network='UBA', sampling='daily', target_dim='variables', target_var='o3', - time_dim='datetime', window_history_size=7, window_lead_time=3, - interpolation_limit=0 - ) # transformation={'method': 'standardise'}) - sp2 = DataHandlerSingleStation(data_path='/home/felix/PycharmProjects/mlt_new/data/', station='DEBY122', - statistics_per_var=statistics_per_var, station_type='background', - network='UBA', sampling='daily', target_dim='variables', target_var='o3', - time_dim='datetime', window_history_size=7, window_lead_time=3, - transformation={'method': 'standardise'}) - sp2.transform(inverse=True) - sp.get_X() - sp.get_Y() - print(len(sp)) - print(sp.shape) - print(sp) diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index 4707fd580562a68fd6b2dc0843551905e70d7e50..07fdc41fc4dae49bd44a071dd2228c4bff860b04 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -3,7 +3,6 @@ __author__ = 'Lukas Leufen' __date__ = '2020-08-26' -import inspect import copy import numpy as np import pandas as pd @@ -13,8 +12,7 @@ from functools import partial import logging from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation from mlair.data_handler import DefaultDataHandler -from mlair.helpers import remove_items, to_list, TimeTrackingWrapper, statistics -from mlair.helpers.filter import KolmogorovZurbenkoFilterMovingWindow as KZFilter +from mlair.helpers import to_list, TimeTrackingWrapper, statistics from mlair.helpers.filter import FIRFilter, ClimateFIRFilter, omega_null_kzf # define a more general date type for type hinting @@ -40,7 +38,6 @@ str_or_list = Union[str, List[str]] class DataHandlerFilterSingleStation(DataHandlerSingleStation): """General data handler for a single station to be used by a superior data handler.""" - _requirements = remove_items(DataHandlerSingleStation.requirements(), "station") _hash = DataHandlerSingleStation._hash + ["filter_dim"] DEFAULT_FILTER_DIM = "filter" @@ -119,24 +116,15 @@ class DataHandlerFilter(DefaultDataHandler): self.use_filter_branches = use_filter_branches super().__init__(*args, **kwargs) - @classmethod - 
def own_args(cls, *args): - """Return all arguments (including kwonlyargs).""" - super_own_args = DefaultDataHandler.own_args(*args) - arg_spec = inspect.getfullargspec(cls) - list_of_args = arg_spec.args + arg_spec.kwonlyargs + super_own_args - return remove_items(list_of_args, ["self"] + list(args)) - class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): """Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered.""" - _requirements = remove_items(DataHandlerFilterSingleStation.requirements(), "station") _hash = DataHandlerFilterSingleStation._hash + ["filter_cutoff_period", "filter_order", "filter_window_type"] DEFAULT_WINDOW_TYPE = ("kaiser", 5) - def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, **kwargs): + def __init__(self, *args, filter_cutoff_period, filter_order, filter_window_type=DEFAULT_WINDOW_TYPE, plot_path=None, **kwargs): # self.original_data = None # ToDo: implement here something to store unfiltered data self.fs = self._get_fs(**kwargs) if filter_window_type == "kzf": @@ -147,6 +135,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): self.filter_order = self._prepare_filter_order(filter_order, removed_index, self.fs) self.filter_window_type = filter_window_type self.unfiltered_name = "unfiltered" + self.plot_path = plot_path # use this path to create insight plots super().__init__(*args, **kwargs) @staticmethod @@ -165,14 +154,11 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): @staticmethod def _prepare_filter_cutoff_period(filter_cutoff_period, fs): """Frequency must be smaller than the sampling frequency fs. Otherwise remove given cutoff period pair.""" - cutoff_tmp = (lambda x: [x] if isinstance(x, tuple) else to_list(x))(filter_cutoff_period) cutoff = [] removed = [] - for i, (low, high) in enumerate(cutoff_tmp): - low = low if (low is None or low > 2. / fs) else None - high = high if (high is None or high > 2. / fs) else None - if any([low, high]): - cutoff.append((low, high)) + for i, period in enumerate(to_list(filter_cutoff_period)): + if period > 2. / fs: + cutoff.append(period) else: removed.append(i) return cutoff, removed @@ -187,8 +173,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): @staticmethod def _period_to_freq(cutoff_p): - return list(map(lambda x: (1. / x[0] if x[0] is not None else None, 1. / x[1] if x[1] is not None else None), - cutoff_p)) + return [1. 
/ x for x in cutoff_p] @staticmethod def _get_fs(**kwargs): @@ -205,10 +190,11 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): def apply_filter(self): """Apply FIR filter only on inputs.""" fir = FIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq, - self.filter_window_type, self.target_dim) - self.fir_coeff = fir.filter_coefficients() - fir_data = fir.filtered_data() - self.input_data = xr.concat(fir_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) + self.filter_window_type, self.target_dim, self.time_dim, station_name=self.station[0], + minimum_length=self.window_history_size, offset=self.window_history_offset, plot_path=self.plot_path) + self.fir_coeff = fir.filter_coefficients + filter_data = fir.filtered_data + self.input_data = xr.concat(filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) # this is just a code snippet to check the results of the kz filter # import matplotlib # matplotlib.use("TkAgg") @@ -216,22 +202,17 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot() # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") - def create_filter_index(self) -> pd.Index: + def create_filter_index(self, add_unfiltered_index=True) -> pd.Index: """ - Create name for filter dimension. Use 'high' or 'low' for high/low pass data and 'bandi' for band pass data with - increasing numerator i (starting from 1). If 1 low, 2 band, and 1 high pass filter is used the filter index will - become to ['low', 'band1', 'band2', 'high']. + Round cut off periods in days and append 'res' for residuum index. + + Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append + 'res' for residuum index. Add index unfiltered if the raw / unfiltered data is appended to data in addition. """ - index = [] - band_num = 1 - for (low, high) in self.filter_cutoff_period: - if low is None: - index.append("low") - elif high is None: - index.append("high") - else: - index.append(f"band{band_num}") - band_num += 1 + index = np.round(self.filter_cutoff_period, 1) + f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) + index = list(map(f, index.tolist())) + index = list(map(lambda x: str(x) + "d", index)) + ["res"] self.filter_dim_order = index return pd.Index(index, name=self.filter_dim) @@ -240,7 +221,7 @@ class DataHandlerFirFilterSingleStation(DataHandlerFilterSingleStation): def _extract_lazy(self, lazy_data): _data, _meta, _input_data, _target_data, self.fir_coeff, self.filter_dim_order = lazy_data - super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) + super()._extract_lazy((_data, _meta, _input_data, _target_data)) def transform(self, data_in, dim: Union[str, int] = 0, inverse: bool = False, opts=None, transformation_dim=None): @@ -325,67 +306,6 @@ class DataHandlerFirFilter(DataHandlerFilter): data_handler = DataHandlerFirFilterSingleStation data_handler_transformation = DataHandlerFirFilterSingleStation - - -class DataHandlerKzFilterSingleStation(DataHandlerFilterSingleStation): - """Data handler for a single station to be used by a superior data handler. 
Inputs are kz filtered.""" - - _requirements = remove_items(inspect.getfullargspec(DataHandlerFilterSingleStation).args, ["self", "station"]) - _hash = DataHandlerFilterSingleStation._hash + ["kz_filter_length", "kz_filter_iter"] - - def __init__(self, *args, kz_filter_length, kz_filter_iter, **kwargs): - self._check_sampling(**kwargs) - # self.original_data = None # ToDo: implement here something to store unfiltered data - self.kz_filter_length = to_list(kz_filter_length) - self.kz_filter_iter = to_list(kz_filter_iter) - self.cutoff_period = None - self.cutoff_period_days = None - super().__init__(*args, **kwargs) - - @TimeTrackingWrapper - def apply_filter(self): - """Apply kolmogorov zurbenko filter only on inputs.""" - kz = KZFilter(self.input_data, wl=self.kz_filter_length, itr=self.kz_filter_iter, filter_dim=self.time_dim) - filtered_data: List[xr.DataArray] = kz.run() - self.cutoff_period = kz.period_null() - self.cutoff_period_days = kz.period_null_days() - self.input_data = xr.concat(filtered_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) - # this is just a code snippet to check the results of the kz filter - # import matplotlib - # matplotlib.use("TkAgg") - # import matplotlib.pyplot as plt - # self.input_data.sel(filter="74d", variables="temp", Stations="DEBW107").plot() - # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") - - def create_filter_index(self) -> pd.Index: - """ - Round cut off periods in days and append 'res' for residuum index. - - Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append - 'res' for residuum index. - """ - index = np.round(self.cutoff_period_days, 1) - f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) - index = list(map(f, index.tolist())) - index = list(map(lambda x: str(x) + "d", index)) + ["res"] - self.filter_dim_order = index - return pd.Index(index, name=self.filter_dim) - - def _create_lazy_data(self): - return [self._data, self.meta, self.input_data, self.target_data, self.cutoff_period, self.cutoff_period_days, - self.filter_dim_order] - - def _extract_lazy(self, lazy_data): - _data, _meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days, \ - self.filter_dim_order = lazy_data - super(__class__, self)._extract_lazy((_data, _meta, _input_data, _target_data)) - - -class DataHandlerKzFilter(DataHandlerFilter): - """Data handler using kz filtered data.""" - - data_handler = DataHandlerKzFilterSingleStation - data_handler_transformation = DataHandlerKzFilterSingleStation _requirements = data_handler.requirements() @@ -407,21 +327,20 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation :param apriori_diurnal: use diurnal anomalies of each hour as addition to the apriori information type chosen by parameter apriori_type. This is only applicable for hourly resolution data. 
""" - - _requirements = remove_items(DataHandlerFirFilterSingleStation.requirements(), "station") - _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal"] + _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts", "apriori_diurnal", + "extend_length_opts"] _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"] def __init__(self, *args, apriori=None, apriori_type=None, apriori_diurnal=False, apriori_sel_opts=None, - plot_path=None, name_affix=None, **kwargs): + name_affix=None, extend_length_opts=None, **kwargs): self.apriori_type = apriori_type self.climate_filter_coeff = None # coefficents of the used FIR filter self.apriori = apriori # exogenous apriori information or None to calculate from data (endogenous) self.apriori_diurnal = apriori_diurnal self.all_apriori = None # collection of all apriori information self.apriori_sel_opts = apriori_sel_opts # ensure to separate exogenous and endogenous information - self.plot_path = plot_path # use this path to create insight plots self.plot_name_affix = name_affix + self.extend_length_opts = extend_length_opts if extend_length_opts is not None else {} super().__init__(*args, **kwargs) @TimeTrackingWrapper @@ -429,14 +348,14 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation """Apply FIR filter only on inputs.""" self.apriori = self.apriori.get(str(self)) if isinstance(self.apriori, dict) else self.apriori logging.info(f"{self.station}: call ClimateFIRFilter") - plot_name = str(self) # if self.plot_name_affix is None else f"{str(self)}_{self.plot_name_affix}" climate_filter = ClimateFIRFilter(self.input_data.astype("float32"), self.fs, self.filter_order, self.filter_cutoff_freq, self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim, apriori_type=self.apriori_type, apriori=self.apriori, apriori_diurnal=self.apriori_diurnal, sel_opts=self.apriori_sel_opts, - plot_path=self.plot_path, plot_name=plot_name, - minimum_length=self.window_history_size, new_dim=self.window_dim) + plot_path=self.plot_path, + minimum_length=self.window_history_size, new_dim=self.window_dim, + station_name=self.station[0], extend_length_opts=self.extend_length_opts) self.climate_filter_coeff = climate_filter.filter_coefficients # store apriori information: store all if residuum_stat method was used, otherwise just store initial apriori @@ -446,8 +365,18 @@ class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation self.apriori = climate_filter.initial_apriori_data self.all_apriori = climate_filter.apriori_data - climate_filter_data = [c.sel({self.window_dim: slice(-self.window_history_size, 0)}) for c in - climate_filter.filtered_data] + if isinstance(self.extend_length_opts, int): + climate_filter_data = [c.sel({self.window_dim: slice(-self.window_history_size, self.extend_length_opts)}) + for c in climate_filter.filtered_data] + else: + climate_filter_data = [] + for c in climate_filter.filtered_data: + coll_tmp = [] + for v in c.coords[self.target_dim].values: + upper_lim = self.extend_length_opts.get(v, 0) + coll_tmp.append(c.sel({self.target_dim: v, + self.window_dim: slice(-self.window_history_size, upper_lim)})) + climate_filter_data.append(xr.concat(coll_tmp, self.target_dim)) # create input data with filter index input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(add_unfiltered_index=False), diff --git a/mlair/data_handler/default_data_handler.py 
b/mlair/data_handler/default_data_handler.py index 9b8efe811d3ca987a9a67765cdde8ac1e73a9cca..d158726e5f433d40cfa272e6a9c7f808057f88e4 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -21,7 +21,7 @@ import numpy as np import xarray as xr from mlair.data_handler.abstract_data_handler import AbstractDataHandler -from mlair.helpers import remove_items, to_list +from mlair.helpers import remove_items, to_list, TimeTrackingWrapper from mlair.helpers.join import EmptyQueryResult @@ -33,8 +33,9 @@ class DefaultDataHandler(AbstractDataHandler): from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation as data_handler_transformation - _requirements = remove_items(inspect.getfullargspec(data_handler).args, ["self", "station"]) + _requirements = data_handler.requirements() _store_attributes = data_handler.store_attributes() + _skip_args = AbstractDataHandler._skip_args + ["id_class"] DEFAULT_ITER_DIM = "Stations" DEFAULT_TIME_DIM = "datetime" @@ -73,10 +74,6 @@ class DefaultDataHandler(AbstractDataHandler): def _create_collection(self): return [self.id_class] - @classmethod - def requirements(cls): - return remove_items(super().requirements(), "id_class") - def _reset_data(self): self._X, self._Y, self._X_extreme, self._Y_extreme = None, None, None, None gc.collect() @@ -164,6 +161,7 @@ class DefaultDataHandler(AbstractDataHandler): self._reset_data() if no_data is True else None return self._to_numpy([Y]) if as_numpy is True else Y + @TimeTrackingWrapper def harmonise_X(self): X_original, Y_original = self.get_X_original(), self.get_Y_original() dim = self.time_dim @@ -186,6 +184,7 @@ class DefaultDataHandler(AbstractDataHandler): def apply_transformation(self, data, base="target", dim=0, inverse=False): return self.id_class.apply_transformation(data, dim=dim, base=base, inverse=inverse) + @TimeTrackingWrapper def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, timedelta: Tuple[int, str] = (1, 'm'), dim=DEFAULT_TIME_DIM): """ diff --git a/mlair/data_handler/input_bootstraps.py b/mlair/data_handler/input_bootstraps.py index 187f09050bb39a953ac58c2b7fca54b6a207aed1..b8ad614f2317e804d415b23308df760f4dd8da7f 100644 --- a/mlair/data_handler/input_bootstraps.py +++ b/mlair/data_handler/input_bootstraps.py @@ -123,11 +123,12 @@ class BootstrapIteratorVariable(BootstrapIterator): _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X)) _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1) for index in range(len(_X)): - single_variable = _X[index].sel({self._dimension: [dimension]}) - bootstrapped_variable = self.apply_bootstrap_method(single_variable.values) - bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords, - dims=single_variable.dims) - _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims) + if dimension in _X[index].coords[self._dimension]: + single_variable = _X[index].sel({self._dimension: [dimension]}) + bootstrapped_variable = self.apply_bootstrap_method(single_variable.values) + bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords, + dims=single_variable.dims) + _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims) self._position += 1 except IndexError: raise StopIteration() diff --git 
a/mlair/helpers/filter.py b/mlair/helpers/filter.py
index 36c93b04486fc9be013af2c4f34d2b3ee1bd84c2..9a61de715bc02e57a41ab6c2b9a62de7157acf07 100644
--- a/mlair/helpers/filter.py
+++ b/mlair/helpers/filter.py
@@ -1,6 +1,6 @@
 import gc
 import warnings
-from typing import Union, Callable, Tuple
+from typing import Union, Callable, Tuple, Dict, Any
 import logging
 import os
 import time
@@ -17,49 +17,157 @@ from mlair.helpers import to_list, TimeTrackingWrapper, TimeTracking
 
 
 class FIRFilter:
+    from mlair.plotting.data_insight_plotting import PlotFirFilter
+
+    def __init__(self, data, fs, order, cutoff, window, var_dim, time_dim, station_name=None, minimum_length=0, offset=0, plot_path=None):
+        self._filtered = []
+        self._h = []
+        self.data = data
+        self.fs = fs
+        self.order = order
+        self.cutoff = cutoff
+        self.window = window
+        self.var_dim = var_dim
+        self.time_dim = time_dim
+        self.station_name = station_name
+        self.minimum_length = minimum_length
+        self.offset = offset
+        self.plot_path = plot_path
+        self.run()
 
-    def __init__(self, data, fs, order, cutoff, window, dim):
-
+    def run(self):
+        logging.info(f"{self.station_name}: start {self.__class__.__name__}")
         filtered = []
         h = []
-        for i in range(len(order)):
-            fi, hi = fir_filter(data, fs, order=order[i], cutoff_low=cutoff[i][0], cutoff_high=cutoff[i][1],
-                                window=window, dim=dim, h=None, causal=True, padlen=None)
+        input_data = self.data.__deepcopy__()
+
+        # collect some data for visualization
+        plot_pos = np.array([0.25, 1.5, 2.75, 4]) * 365 * self.fs
+        plot_dates = [input_data.isel({self.time_dim: int(pos)}).coords[self.time_dim].values for pos in plot_pos if
+                      pos < len(input_data.coords[self.time_dim])]
+        plot_data = []
+
+        for i in range(len(self.order)):
+            # apply filter
+            fi, hi = self.fir_filter(input_data, self.fs, self.cutoff[i], self.order[i], time_dim=self.time_dim,
+                                     var_dim=self.var_dim, window=self.window, station_name=self.station_name)
             filtered.append(fi)
             h.append(hi)
+            # visualization
+            plot_data.extend(self.create_visualization(fi, input_data, plot_dates, self.time_dim, self.fs, hi,
+                                                       self.minimum_length, self.order, i, self.offset, self.var_dim))
+            # calculate residuum
+            input_data = input_data - fi
+
+        # add last residuum to filtered
+        filtered.append(input_data)
 
         self._filtered = filtered
         self._h = h
 
+        # visualize
+        if self.plot_path is not None:
+            try:
+                self.PlotFirFilter(self.plot_path, plot_data, self.station_name)  # not working when t0 != 0
+            except Exception as e:
+                logging.info(f"Could not plot fir filter due to following reason:\n{e}")
+
+    def create_visualization(self, filtered, filter_input_data, plot_dates, time_dim, sampling,
+                             h, minimum_length, order, i, offset, var_dim):  # pragma: no cover
+        plot_data = []
+        for viz_date in set(plot_dates).intersection(filtered.coords[time_dim].values):
+            try:
+                if i < len(order) - 1:
+                    minimum_length += order[i+1]
+
+                td_type = {1: "D", 24: "h"}.get(sampling)
+                length = len(h)
+                extend_length_history = minimum_length + int((length + 1) / 2)
+                extend_length_future = int((length + 1) / 2) + 1
+                t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type)
+                t_plus = viz_date + np.timedelta64(int(extend_length_future + offset), td_type)
+                time_slice = slice(t_minus, t_plus - np.timedelta64(1, td_type))
+                plot_data.append({"t0": viz_date, "filter_input": filter_input_data.sel({time_dim: time_slice}),
+                                  "filtered": filtered.sel({time_dim: time_slice}), "h": h, "time_dim": time_dim,
+                                  "var_dim": var_dim})
+            except Exception:
+                pass
+        return plot_data
+
+    @property 
def filter_coefficients(self):
         return self._h
 
+    @property
     def filtered_data(self):
         return self._filtered
 
-    #
-    # y, h = fir_filter(station_data.values.flatten(), fs, order[0], cutoff_low=cutoff[0][0], cutoff_high=cutoff[0][1],
-    #                   window=window)
-    # filtered = xr.ones_like(station_data) * y.reshape(station_data.values.shape)
-    # # band pass
-    # y_band, h_band = fir_filter(station_data.values.flatten(), fs, order[1], cutoff_low=cutoff[1][0],
-    #                             cutoff_high=cutoff[1][1], window=window)
-    # filtered_band = xr.ones_like(station_data) * y_band.reshape(station_data.values.shape)
-    # # band pass 2
-    # y_band_2, h_band_2 = fir_filter(station_data.values.flatten(), fs, order[2], cutoff_low=cutoff[2][0],
-    #                                 cutoff_high=cutoff[2][1], window=window)
-    # filtered_band_2 = xr.ones_like(station_data) * y_band_2.reshape(station_data.values.shape)
-    # # high pass
-    # y_high, h_high = fir_filter(station_data.values.flatten(), fs, order[3], cutoff_low=cutoff[3][0],
-    #                             cutoff_high=cutoff[3][1], window=window)
-    # filtered_high = xr.ones_like(station_data) * y_high.reshape(station_data.values.shape)
-
-
-class ClimateFIRFilter:
+
+    @TimeTrackingWrapper
+    def fir_filter(self, data, fs, cutoff_high, order, sampling="1d", time_dim="datetime", var_dim="variables", window: Union[str, Tuple] = "hamming",
+                   minimum_length=None, new_dim="window", plot_dates=None, station_name=None):
+
+        # calculate FIR filter coefficients
+        h = self._calculate_filter_coefficients(window, order, cutoff_high, fs)
+
+        coll = []
+        for var in data.coords[var_dim]:
+            d = data.sel({var_dim: var})
+            filt = xr.apply_ufunc(fir_filter_convolve, d,
+                                  input_core_dims=[[time_dim]], output_core_dims=[[time_dim]],
+                                  vectorize=True, kwargs={"h": h}, output_dtypes=[d.dtype])
+            coll.append(filt)
+        filtered = xr.concat(coll, var_dim)
+
+        # create result array with the same shape as the input data, gaps are filled by nans
+        filtered = self._create_full_filter_result_array(data, filtered, time_dim, station_name)
+        return filtered, h
+
+    @staticmethod
+    def _calculate_filter_coefficients(window: Union[str, tuple], order: Union[int, tuple], cutoff_high: float,
+                                       fs: float) -> np.array:
+        """
+        Calculate filter coefficients for moving window using scipy's signal package for common filter types and local
+        method firwin_kzf for Kolmogorov Zurbenko filter (kzf). The filter is a low-pass filter.
+
+        :param window: name of the window type which is either a string with the window's name or a tuple containing
+            the name and additional parameters (e.g. `("kaiser", 5)`)
+        :param order: order of the filter to create as int or parameters m and k of kzf
+        :param cutoff_high: cutoff frequency to use for low-pass filter in frequency of fs
+        :param fs: sampling frequency of time series
+        """
+        if window == "kzf":
+            h = firwin_kzf(*order)
+        else:
+            h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window)
+        return h
+
+    @staticmethod
+    def _create_full_filter_result_array(template_array: xr.DataArray, result_array: xr.DataArray, new_dim: str,
+                                         station_name: str = None) -> xr.DataArray:
+        """
+        Create result filter array with the same shape as the given template data (should be the original input data
+        before filtering). All gaps are filled with nans. 
+
+        :param template_array: this array is used as template for shape and ordering of dims
+        :param result_array: array with data that are filled into template
+        :param new_dim: new dimension which is shifted/appended to/at the end (if present or not)
+        :param station_name: string that is attached to logging (default None)
+        """
+        logging.debug(f"{station_name}: create res_full")
+        new_coords = {**{k: template_array.coords[k].values for k in template_array.coords if k != new_dim},
+                      new_dim: result_array.coords[new_dim]}
+        dims = [*template_array.dims, new_dim] if new_dim not in template_array.dims else template_array.dims
+        result_array = result_array.transpose(*dims)
+        return result_array.broadcast_like(xr.DataArray(dims=dims, coords=new_coords))
+
+
+class ClimateFIRFilter(FIRFilter):
     from mlair.plotting.data_insight_plotting import PlotClimateFirFilter
 
     def __init__(self, data, fs, order, cutoff, window, time_dim, var_dim, apriori=None, apriori_type=None,
-                 apriori_diurnal=False, sel_opts=None, plot_path=None, plot_name=None,
-                 minimum_length=None, new_dim=None):
+                 apriori_diurnal=False, sel_opts=None, plot_path=None,
+                 minimum_length=None, new_dim=None, station_name=None, extend_length_opts: Union[dict, int] = 0):
         """
         :param data: data to filter
         :param fs: sampling frequency in 1/days -> 1d: fs=1 -> 1H: fs=24
@@ -75,111 +183,120 @@ class ClimateFIRFilter:
             residua is used ("residuum_stats").
         :param apriori_diurnal: Use diurnal cycle as additional apriori information (only applicable for hourly
            resolved data). The mean anomaly of each hour is added to the apriori_type information.
+        :param extend_length_opts: shift the switch between historical data and apriori estimation by the given
+            values (default 0). Must either be a dictionary with keys available in var_dim or a single value that is
+            applied to all data.
         """
-        logging.info(f"{plot_name}: start init ClimateFIRFilter")
+        #todo add extend_length_opts
+        # adjust all parts of code marked as todos
+        # think about different behaviour when using different extend_length_opts (is this part of dh?) 
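[Aside, not part of the patch: the `extend_length_opts` parameter documented above is consumed in the data handler's apply_filter (see the data_handler_with_filter.py hunk earlier) either as a scalar or per variable. A toy, self-contained illustration of the int-vs-dict slicing — variable names, dimension names and sizes are invented.]

import numpy as np
import xarray as xr

window_history_size = 3
extend_length_opts = {"o3": 2}  # per-variable upper bound; missing keys default to 0

data = xr.DataArray(np.random.rand(2, 9), dims=("variables", "window"),
                    coords={"variables": ["o3", "temp"], "window": np.arange(-5, 4)})

if isinstance(extend_length_opts, int):
    # one upper bound for all variables
    sliced = data.sel(window=slice(-window_history_size, extend_length_opts))
else:
    # per-variable upper bound, then re-assemble along the variable dimension
    parts = []
    for v in data.coords["variables"].values:
        upper = extend_length_opts.get(v, 0)
        parts.append(data.sel(variables=v, window=slice(-window_history_size, upper)))
    sliced = xr.concat(parts, dim="variables")  # ragged ends are aligned with NaNs

print(sliced.coords["window"].values)  # -> [-3 -2 -1  0  1  2]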
+ + self._apriori = apriori + self.apriori_type = apriori_type + self.apriori_diurnal = apriori_diurnal + self._apriori_list = [] + self.sel_opts = sel_opts + self.minimum_length = minimum_length + self.new_dim = new_dim self.plot_path = plot_path - self.plot_name = plot_name self.plot_data = [] + self.extend_length_opts = extend_length_opts + super().__init__(data, fs, order, cutoff, window, var_dim, time_dim, station_name=station_name) + + def run(self): filtered = [] h = [] - if sel_opts is not None: - sel_opts = sel_opts if isinstance(sel_opts, dict) else {time_dim: sel_opts} - sampling = {1: "1d", 24: "1H"}.get(int(fs)) - logging.debug(f"{plot_name}: create diurnal_anomalies") - if apriori_diurnal is True and sampling == "1H": - # diurnal_anomalies = self.create_hourly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, - # as_anomaly=True) - diurnal_anomalies = self.create_seasonal_hourly_mean(data, sel_opts=sel_opts, sampling=sampling, - time_dim=time_dim, - as_anomaly=True) + if self.sel_opts is not None: + self.sel_opts = self.sel_opts if isinstance(self.sel_opts, dict) else {self.time_dim: self.sel_opts} + sampling = {1: "1d", 24: "1H"}.get(int(self.fs)) + logging.debug(f"{self.station_name}: create diurnal_anomalies") + if self.apriori_diurnal is True and sampling == "1H": + diurnal_anomalies = self.create_seasonal_hourly_mean(self.data, self.time_dim, sel_opts=self.sel_opts, + sampling=sampling, as_anomaly=True) else: diurnal_anomalies = 0 - logging.debug(f"{plot_name}: create monthly apriori") - if apriori is None: - apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, - time_dim=time_dim) + diurnal_anomalies - logging.debug(f"{plot_name}: apriori shape = {apriori.shape}") - apriori_list = to_list(apriori) - input_data = data.__deepcopy__() + logging.debug(f"{self.station_name}: create monthly apriori") + if self._apriori is None: + self._apriori = self.create_monthly_mean(self.data, self.time_dim, sel_opts=self.sel_opts, + sampling=sampling) + diurnal_anomalies + logging.debug(f"{self.station_name}: apriori shape = {self._apriori.shape}") + apriori_list = to_list(self._apriori) + input_data = self.data.__deepcopy__() # for viz plot_dates = None # create tmp dimension to apply filter, search for unused name - new_dim = self._create_tmp_dimension(input_data) if new_dim is None else new_dim + new_dim = self._create_tmp_dimension(input_data) if self.new_dim is None else self.new_dim - for i in range(len(order)): - logging.info(f"{plot_name}: start filter for order {order[i]}") + for i in range(len(self.order)): + logging.info(f"{self.station_name}: start filter for order {self.order[i]}") # calculate climatological filter - # ToDo: remove all methods except the vectorized version - _minimum_length = self._minimum_length(order, minimum_length, i, window) - fi, hi, apriori, plot_data = self.clim_filter(input_data, fs, cutoff[i], order[i], - apriori=apriori_list[i], - sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, - window=window, var_dim=var_dim, + _minimum_length = self._minimum_length(self.order, self.minimum_length, i, self.window) + fi, hi, apriori, plot_data = self.clim_filter(input_data, self.fs, self.cutoff[i], self.order[i], + apriori=apriori_list[i], sel_opts=self.sel_opts, + sampling=sampling, time_dim=self.time_dim, + window=self.window, var_dim=self.var_dim, minimum_length=_minimum_length, new_dim=new_dim, - plot_dates=plot_dates) + plot_dates=plot_dates, station_name=self.station_name, + 
extend_length_opts=self.extend_length_opts) - logging.info(f"{plot_name}: finished clim_filter calculation") - if minimum_length is None: + logging.info(f"{self.station_name}: finished clim_filter calculation") + if self.minimum_length is None: filtered.append(fi) else: - filtered.append(fi.sel({new_dim: slice(-minimum_length, 0)})) + filtered.append(fi.sel({new_dim: slice(-self.minimum_length, None)})) h.append(hi) gc.collect() self.plot_data.append(plot_data) plot_dates = {e["t0"] for e in plot_data} # calculate residuum - logging.info(f"{plot_name}: calculate residuum") + logging.info(f"{self.station_name}: calculate residuum") coord_range = range(fi.coords[new_dim].values.min(), fi.coords[new_dim].values.max() + 1) if new_dim in input_data.coords: input_data = input_data.sel({new_dim: coord_range}) - fi else: - input_data = self._shift_data(input_data, coord_range, time_dim, var_dim, new_dim) - fi + input_data = self._shift_data(input_data, coord_range, self.time_dim, new_dim) - fi # create new apriori information for next iteration if no further apriori is provided if len(apriori_list) <= i + 1: - logging.info(f"{plot_name}: create diurnal_anomalies") - if apriori_diurnal is True and sampling == "1H": - # diurnal_anomalies = self.create_hourly_mean(input_data.sel({new_dim: 0}, drop=True), - # sel_opts=sel_opts, sampling=sampling, - # time_dim=time_dim, as_anomaly=True) + logging.info(f"{self.station_name}: create diurnal_anomalies") + if self.apriori_diurnal is True and sampling == "1H": diurnal_anomalies = self.create_seasonal_hourly_mean(input_data.sel({new_dim: 0}, drop=True), - sel_opts=sel_opts, sampling=sampling, - time_dim=time_dim, as_anomaly=True) + self.time_dim, sel_opts=self.sel_opts, + sampling=sampling, as_anomaly=True) else: diurnal_anomalies = 0 - logging.info(f"{plot_name}: create monthly apriori") - if apriori_type is None or apriori_type == "zeros": # zero version + logging.info(f"{self.station_name}: create monthly apriori") + if self.apriori_type is None or self.apriori_type == "zeros": # zero version apriori_list.append(xr.zeros_like(apriori_list[i]) + diurnal_anomalies) - elif apriori_type == "residuum_stats": # calculate monthly statistic on residuum + elif self.apriori_type == "residuum_stats": # calculate monthly statistic on residuum apriori_list.append( - -self.create_monthly_mean(input_data.sel({new_dim: 0}, drop=True), sel_opts=sel_opts, - sampling=sampling, - time_dim=time_dim) + diurnal_anomalies) + -self.create_monthly_mean(input_data.sel({new_dim: 0}, drop=True), self.time_dim, + sel_opts=self.sel_opts, sampling=sampling) + diurnal_anomalies) else: - raise ValueError(f"Cannot handle unkown apriori type: {apriori_type}. Please choose from None, " - f"`zeros` or `residuum_stats`.") + raise ValueError(f"Cannot handle unkown apriori type: {self.apriori_type}. 
Please choose from None," + f" `zeros` or `residuum_stats`.") # add last residuum to filtered - if minimum_length is None: + if self.minimum_length is None: filtered.append(input_data) else: - filtered.append(input_data.sel({new_dim: slice(-minimum_length, 0)})) - # filtered.append(input_data) + filtered.append(input_data.sel({new_dim: slice(-self.minimum_length, None)})) + self._filtered = filtered self._h = h - self._apriori = apriori_list + self._apriori_list = apriori_list # visualize if self.plot_path is not None: try: - self.PlotClimateFirFilter(self.plot_path, self.plot_data, sampling, plot_name) + self.PlotClimateFirFilter(self.plot_path, self.plot_data, sampling, self.station_name) except Exception as e: logging.info(f"Could not plot climate fir filter due to following reason:\n{e}") @staticmethod - def _minimum_length(order, minimum_length, pos, window): + def _minimum_length(order: list, minimum_length: Union[int, None], pos: int, window: Union[str, tuple]) -> int: next_order = 0 if pos + 1 < len(order): next_order = order[pos + 1] @@ -190,8 +307,16 @@ class ClimateFIRFilter: return next_order if next_order > 0 else None @staticmethod - def create_unity_array(data, time_dim, extend_range=366): - """Create a xr data array filled with ones. time_dim is extended by extend_range days in future and past.""" + def create_monthly_unity_array(data: xr.DataArray, time_dim: str, extend_range: int = 366) -> xr.DataArray: + """ + Create a xarray data array filled with ones with monthly resolution (set on 16th of month). Data is extended by + extend_range days in future and past along time_dim. + + :param data: data to create monthly unity array from, must contain dimension time_dim + :param time_dim: name of temporal dimension + :param extend_range: number of days to extend data (default 366) + :returns: xarray in monthly resolution (centered at 16th day of month) with all values equal to 1 + """ coords = data.coords # extend time_dim by given extend_range days @@ -206,11 +331,28 @@ class ClimateFIRFilter: # loffset is required because resampling uses last day in month as resampling timestamp return new_array.resample({time_dim: "1m"}, loffset=datetime.timedelta(days=-15)).max() - def create_monthly_mean(self, data, sel_opts=None, sampling="1d", time_dim="datetime"): - """Calculate monthly statistics.""" + def create_monthly_mean(self, data: xr.DataArray, time_dim: str, sel_opts: dict = None, sampling: str = "1d") \ + -> xr.DataArray: + """ + Calculate monthly means (12 values) and return a data array with same resolution as given data containing these + monthly mean values. Sampling points are the 16th of each month (this value is equal to the true monthly mean) + and all other values between two points are interpolated linearly. It is possible to apply some pre-selection + to use only a subset of given data using the sel_opts parameter. Only data from this subset are used to + calculate the monthly statistic. + + :param data: data to apply statistical calculation on + :param time_dim: name of temporal axis + :param sel_opts: selection options as dict to select a subset of data (default None). A given sel_opts with + `sel_opts={<time_dim>: "2006"}` forces the method e.g. to derive the monthly means only from data of the + year 2006. + :param sampling: sampling of the returned data (default 1d) + :returns: array in desired resolution containing interpolated monthly values. 
Months with no valid data are
+            returned as np.nan which also affects data in the neighbouring months (before / after sampling points which
+            are the 16th of each month).
+        """
         # create unity xarray in monthly resolution with sampling point in mid of each month
-        monthly = self.create_unity_array(data, time_dim)
+        monthly = self.create_monthly_unity_array(data, time_dim)

         # apply selection if given (only use subset for monthly means)
         if sel_opts is not None:
@@ -225,35 +367,68 @@ class ClimateFIRFilter:

         # transform monthly information into original sampling rate
         return monthly.resample({time_dim: sampling}).interpolate()

-        # for month in monthly_mean.month.values:
-        #     loc = (monthly[f"{time_dim}.month"] == month)
-        #     monthly.loc[{time_dim: loc}] = monthly_mean.sel(month=month, drop=True)
-        # aggregate monthly information (shift by half month, because resample base is last day)
-        # return monthly.resample({time_dim: "1m"}).max().resample({time_dim: sampling}).interpolate()
-    @staticmethod
-    def create_hourly_mean(data, sel_opts=None, sampling="1H", time_dim="datetime", as_anomaly=True):
-        """Calculate hourly statistics. Either the absolute value or the anomaly (as_anomaly=True)."""
-        # can only be used for hourly sampling rate
-        assert sampling == "1H"
-
-        # create unity xarray in hourly resolution
-        hourly = xr.ones_like(data)
+    @staticmethod
+    def _compute_hourly_mean_per_month(data: xr.DataArray, time_dim: str, as_anomaly: bool) -> Dict[int, xr.DataArray]:
+        """
+        Calculate for each hour in each month a separate mean value (12 x 24 values in total). Average is either the
+        anomaly of a monthly mean state or the raw mean value.

-        # apply selection if given (only use subset for hourly means)
-        if sel_opts is not None:
-            data = data.sel(**sel_opts)
+        :param data: data to calculate averages on
+        :param time_dim: name of temporal dimension
+        :param as_anomaly: indicates whether to calculate means as anomaly of a monthly mean or as raw mean values.
+        :returns: dictionary containing 12 months each with a 24-valued array (1 entry for each hour)
+        """
+        seasonal_hourly_means = {}
+        for month in data.groupby(f"{time_dim}.month").groups.keys():
+            single_month_data = data.sel({time_dim: (data[f"{time_dim}.month"] == month)})
+            hourly_mean = single_month_data.groupby(f"{time_dim}.hour").mean()
+            if as_anomaly is True:
+                hourly_mean = hourly_mean - hourly_mean.mean("hour")
+            seasonal_hourly_means[month] = hourly_mean
+        return seasonal_hourly_means

-        # create mean for each hour and replace entries in unity array, calculate anomaly if enabled
-        hourly_mean = data.groupby(f"{time_dim}.hour").mean()
-        if as_anomaly is True:
-            hourly_mean = hourly_mean - hourly_mean.mean("hour")
-        for hour in hourly_mean.hour.values:
-            loc = (hourly[f"{time_dim}.hour"] == hour)
-            hourly.loc[{f"{time_dim}": loc}] = hourly_mean.sel(hour=hour)
-        return hourly
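To see what _compute_hourly_mean_per_month returns, the 12 x 24 structure can be reproduced on synthetic data with the same groupby logic; a self-contained sketch (not part of this patch):

    import numpy as np
    import pandas as pd
    import xarray as xr

    index = pd.date_range("2020-01-01", "2021-12-31 23:00", freq="1H")
    values = np.sin(2 * np.pi * index.hour / 24) + np.random.normal(0, 0.1, len(index))
    data = xr.DataArray(values, coords={"datetime": index}, dims="datetime")

    means = {}
    for month in data.groupby("datetime.month").groups.keys():
        single_month = data.sel({"datetime": data["datetime.month"] == month})
        hourly_mean = single_month.groupby("datetime.hour").mean()
        means[month] = hourly_mean - hourly_mean.mean("hour")  # anomaly of the monthly mean

    print(len(means), means[1].shape)  # 12 months, each with 24 hourly values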
+    @staticmethod
+    def _create_seasonal_cycle_of_single_hour_mean(result_arr: xr.DataArray, means: Dict[int, xr.DataArray], hour: int,
+                                                   time_dim: str, sampling: str) -> xr.DataArray:
+        """
+        Use monthly means of a given hour to create an array with interpolated values at the indicated hour for each
+        day of the full time span indicated by given result_arr.
+
+        :param result_arr: template array indicating the full time range and additional dimensions to keep
+        :param means: dictionary containing 24 hourly averages for each month (12 x 24 values in total)
+        :param hour: integer of hour of interest
+        :param time_dim: name of temporal dimension
+        :param sampling: sampling rate to interpolate
+        :returns: array with interpolated averages in sampling resolution containing only values for hour of interest
+        """
+        h_coll = xr.ones_like(result_arr) * np.nan
+        for month in means.keys():
+            hourly_mean_single_month = means[month].sel(hour=hour, drop=True)
+            h_coll = xr.where((h_coll[f"{time_dim}.month"] == month), hourly_mean_single_month, h_coll)
+        h_coll = h_coll.resample({time_dim: sampling}).interpolate()
+        h_coll = h_coll.sel({time_dim: (h_coll[f"{time_dim}.hour"] == hour)})
+        return h_coll
+
+    def create_seasonal_hourly_mean(self, data: xr.DataArray, time_dim: str, sel_opts: Dict[str, Any] = None,
+                                    sampling: str = "1H", as_anomaly: bool = True) -> xr.DataArray:
+        """
+        Compute climatological statistics on an hourly basis either as raw data or anomalies. For each month, an
+        overall mean value (only used if requiring anomalies) and the mean of each hour are calculated. The
+        climatological diurnal cycle is positioned on the 16th of each month and interpolated in between by using a
+        distinct interpolation for each hour of the day. The returned array therefore contains data with a yearly
+        cycle (if anomaly is not calculated) or data without a yearly cycle (if using anomalies). In both cases, the
+        data have an amplitude that varies over the year.
+
+        :param data: data to apply this method to
+        :param time_dim: name of temporal axis
+        :param sel_opts: specific selection options that are applied before calculation of climatological statistics
+            (default None)
+        :param sampling: temporal resolution of data (default "1H")
+        :param as_anomaly: specify whether to use anomalies or raw data including a seasonal cycle of the mean value
+            (default: True)
+        :returns: climatological statistics for given data interpolated with given sampling rate
+        """
-    def create_seasonal_hourly_mean(self, data, sel_opts=None, sampling="1H", time_dim="datetime", as_anomaly=True):
-        """Calculate hourly statistics.
Either the absolute value or the anomaly (as_anomaly=True)."""
         # can only be used for hourly sampling rate
         assert sampling == "1H"
@@ -263,46 +438,44 @@ class ClimateFIRFilter:
             data = data.sel(**sel_opts)

         # create unity xarray in monthly resolution with sampling point in mid of each month
-        monthly = self.create_unity_array(data, time_dim) * np.nan
+        monthly = self.create_monthly_unity_array(data, time_dim) * np.nan

-        seasonal_hourly_means = {}
-
-        for month in data.groupby(f"{time_dim}.month").groups.keys():
-            # select each month
-            single_month_data = data.sel({time_dim: (data[f"{time_dim}.month"] == month)})
-            hourly_mean = single_month_data.groupby(f"{time_dim}.hour").mean()
-            if as_anomaly is True:
-                hourly_mean = hourly_mean - hourly_mean.mean("hour")
-            seasonal_hourly_means[month] = hourly_mean
+        # calculate for each hour in each month a separate mean value
+        seasonal_hourly_means = self._compute_hourly_mean_per_month(data, time_dim, as_anomaly)

+        # create seasonal cycles of these hourly averages
         seasonal_coll = []
         for hour in data.groupby(f"{time_dim}.hour").groups.keys():
-            h_coll = monthly.__deepcopy__()
-            for month in seasonal_hourly_means.keys():
-                hourly_mean_single_month = seasonal_hourly_means[month].sel(hour=hour, drop=True)
-                h_coll = xr.where((h_coll[f"{time_dim}.month"] == month),
-                                  hourly_mean_single_month,
-                                  h_coll)
-            h_coll = h_coll.resample({time_dim: sampling}).interpolate()
-            h_coll = h_coll.sel({time_dim: (h_coll[f"{time_dim}.hour"] == hour)})
-            seasonal_coll.append(h_coll)
-        hourly = xr.concat(seasonal_coll, time_dim).sortby(time_dim).resample({time_dim: sampling}).interpolate()
+            mean_single_hour = self._create_seasonal_cycle_of_single_hour_mean(monthly, seasonal_hourly_means, hour,
+                                                                               time_dim, sampling)
+            seasonal_coll.append(mean_single_hour)

+        # combine all cycles in a common data array
+        hourly = xr.concat(seasonal_coll, time_dim).sortby(time_dim).resample({time_dim: sampling}).interpolate()
         return hourly
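extend_apriori below stretches the climatological estimate by whole years so that it covers the full period of the observations plus one year of future estimates. The core idea in isolation (synthetic data, not part of this patch; the real method additionally computes the required offsets and handles both directions):

    import numpy as np
    import pandas as pd
    import xarray as xr

    index = pd.date_range("2019-01-01", "2019-12-31", freq="1d")
    apriori = xr.DataArray(np.sin(2 * np.pi * index.dayofyear / 365),
                           coords={"datetime": index}, dims="datetime")

    # append a copy of the year, shifted forward by one full year, to cover future dates
    shifted = apriori.assign_coords(datetime=apriori.coords["datetime"].values + np.timedelta64(365, "D"))
    extended = xr.concat([apriori, shifted], dim="datetime")
    print(extended.coords["datetime"].values[[0, -1]])  # now spans two years

Because the apriori is assumed to repeat from year to year, shifting by full years does not distort its seasonal shape.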
"1m", "1d", default "1d") + :param station_name: name to use for logging message (default None) + :returns: array which adjusted temporal coverage derived from apriori """ dates = data.coords[time_dim].values td_type = {"1d": "D", "1H": "h"}.get(sampling) # apriori starts after data if dates[0] < apriori.coords[time_dim].values[0]: - logging.debug(f"{data.coords['Stations'].values[0]}: apriori starts after data") + logging.debug(f"{station_name}: apriori starts after data") # add difference in full years date_diff = abs(dates[0] - apriori.coords[time_dim].values[0]).astype("timedelta64[D]") @@ -323,7 +496,7 @@ class ClimateFIRFilter: # apriori ends before data if dates[-1] + np.timedelta64(365, "D") > apriori.coords[time_dim].values[-1]: - logging.debug(f"{data.coords['Stations'].values[0]}: apriori ends before data") + logging.debug(f"{station_name}: apriori ends before data") # add difference in full years + 1 year (because apriori is used as future estimate) date_diff = abs(dates[-1] - apriori.coords[time_dim].values[-1]).astype("timedelta64[D]") @@ -344,24 +517,171 @@ class ClimateFIRFilter: return apriori + def combine_observation_and_apriori(self, data: xr.DataArray, apriori: xr.DataArray, time_dim: str, new_dim: str, + extend_length_history: int, extend_length_future: int, + extend_length_separator: int = 0) -> xr.DataArray: + """ + Combine historical data / observations ("data") and climatological statistics ("apriori"). Historical data are + used on interval [t0 - extend_length_history, t0] and apriori is used on [t0 + 1, t0 + extend_length_future]. If + indicated by the extend_length_seperator, it is possible to shift end of history interval and start of apriori + interval by given number of time steps. + + :param data: historical data for past values, must contain dimensions time_dim and var_dim and might also have + a new_dim dimension + :param apriori: climatological estimate for future values, must contain dimensions time_dim and var_dim, but + can also have dimension new_dim + :param time_dim: name of temporal dimension + :param new_dim: name of new dim on which data is combined along + :param extend_length_history: number of time steps to use from data + :param extend_length_future: number of time steps to use from apriori (minus 1) + :param extend_length_separator: position of last history value to use (default 0), this position indicates the + last value that is used from data (followed by values from apriori). In other words, end of history + interval and start of apriori interval are shifted by this value from t0 (positive or negative). 
+ :returns: combined data array + """ + # check if shift indicated by extend_length_seperator is inside the outer interval limits + # assert (extend_length_separator > -extend_length_history) and (extend_length_separator < extend_length_future) + + # prepare historical data / observation + if new_dim not in data.coords: + history = self._shift_data(data, range(int(-extend_length_history), extend_length_separator + 1), + time_dim, new_dim) + else: + history = data.sel({new_dim: slice(int(-extend_length_history), extend_length_separator)}) + # prepare climatological statistics + if new_dim not in apriori.coords: + future = self._shift_data(apriori, range(extend_length_separator + 1, + extend_length_separator + extend_length_future), + time_dim, new_dim) + else: + future = apriori.sel({new_dim: slice(extend_length_separator + 1, + extend_length_separator + extend_length_future)}) + + # combine historical data [t0-length,t0+sep] and climatological statistics [t0+sep+1,t0+length] + filter_input_data = xr.concat([history.dropna(time_dim), future], dim=new_dim, join="left") + return filter_input_data + + def create_visualization(self, filtered, data, filter_input_data, plot_dates, time_dim, new_dim, sampling, + extend_length_history, extend_length_future, minimum_length, h, + variable_name, extend_length_opts=None): # pragma: no cover + plot_data = [] + extend_length_opts = 0 if extend_length_opts is None else extend_length_opts + for viz_date in set(plot_dates).intersection(filtered.coords[time_dim].values): + try: + td_type = {"1d": "D", "1H": "h"}.get(sampling) + t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type) + t_plus = viz_date + np.timedelta64(int(extend_length_future + extend_length_opts), td_type) + if new_dim not in data.coords: + tmp_filter_data = self._shift_data(data.sel({time_dim: slice(t_minus, t_plus)}), + range(int(-extend_length_history), + int(extend_length_future + extend_length_opts)), + time_dim, + new_dim).sel({time_dim: viz_date}) + else: + tmp_filter_data = None + valid_range = range(int((len(h) + 1) / 2) if minimum_length is None else minimum_length, + extend_length_opts + 1) + plot_data.append({"t0": viz_date, + "var": variable_name, + "filter_input": filter_input_data.sel({time_dim: viz_date}), + "filter_input_nc": tmp_filter_data, + "valid_range": valid_range, + "time_range": data.sel( + {time_dim: slice(t_minus, t_plus - np.timedelta64(1, td_type))}).coords[ + time_dim].values, + "h": h, + "new_dim": new_dim}) + except: + pass + return plot_data + + @staticmethod + def _get_year_interval(data: xr.DataArray, time_dim: str) -> Tuple[int, int]: + """ + Get year of start and end date of given data. + + :param data: data to extract dates from + :param time_dim: name of temporal axis + :returns: two-element tuple with start and end + """ + start = pd.to_datetime(data.coords[time_dim].min().values).year + end = pd.to_datetime(data.coords[time_dim].max().values).year + return start, end + + @staticmethod + def _calculate_filter_coefficients(window: Union[str, tuple], order: Union[int, tuple], cutoff_high: float, + fs: float) -> np.array: + """ + Calculate filter coefficients for moving window using scipy's signal package for common filter types and local + method firwin_kzf for Kolmogorov Zurbenko filter (kzf). The filter is a low-pass filter. + + :param window: name of the window type which is either a string with the window's name or a tuple containing the + name but also some parameters (e.g. 
`("kaiser", 5)`) + :param order: order of the filter to create as int or parameters m and k of kzf + :param cutoff_high: cutoff frequency to use for low-pass filter in frequency of fs + :param fs: sampling frequency of time series + """ + if window == "kzf": + h = firwin_kzf(*order) + else: + h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window) + return h + + @staticmethod + def _trim_data_to_minimum_length(data: xr.DataArray, extend_length_history: int, dim: str, + minimum_length: int = None, extend_length_opts: int = 0) -> xr.DataArray: + """ + Trim data along given axis between either -minimum_length (if given) or -extend_length_history and 0. + + :param data: data to trim + :param extend_length_history: start number for trim range (transformed to negative), only used if parameter + minimum_length is not provided + :param dim: dim to apply trim on + :param minimum_length: start number for trim range (transformed to negative), preferably used (default None) + :returns: trimmed data + """ + #todo update doc strings + if minimum_length is None: + return data.sel({dim: slice(-extend_length_history, extend_length_opts)}, drop=True) + else: + return data.sel({dim: slice(-minimum_length, extend_length_opts)}, drop=True) + + @staticmethod + def _create_full_filter_result_array(template_array: xr.DataArray, result_array: xr.DataArray, new_dim: str, + station_name: str = None) -> xr.DataArray: + """ + Create result filter array with same shape line given template data (should be the original input data before + filtering the data). All gaps are filled by nans. + + :param template_array: this array is used as template for shape and ordering of dims + :param result_array: array with data that are filled into template + :param new_dim: new dimension which is shifted/appended to/at the end (if present or not) + :param station_name: string that is attached to logging (default None) + """ + logging.debug(f"{station_name}: create res_full") + new_coords = {**{k: template_array.coords[k].values for k in template_array.coords if k != new_dim}, + new_dim: result_array.coords[new_dim]} + dims = [*template_array.dims, new_dim] if new_dim not in template_array.dims else template_array.dims + result_array = result_array.transpose(*dims) + return result_array.broadcast_like(xr.DataArray(dims=dims, coords=new_coords)) + @TimeTrackingWrapper def clim_filter(self, data, fs, cutoff_high, order, apriori=None, sel_opts=None, sampling="1d", time_dim="datetime", var_dim="variables", window: Union[str, Tuple] = "hamming", - minimum_length=None, new_dim="window", plot_dates=None): + minimum_length=None, new_dim="window", plot_dates=None, station_name=None, + extend_length_opts: Union[dict, int] = None): - logging.debug(f"{data.coords['Stations'].values[0]}: extend apriori") + logging.debug(f"{station_name}: extend apriori") + extend_opts = extend_length_opts if extend_length_opts is not None else {} # calculate apriori information from data if not given and extend its range if not sufficient long enough if apriori is None: - apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim) + apriori = self.create_monthly_mean(data, time_dim, sel_opts=sel_opts, sampling=sampling) apriori = apriori.astype(data.dtype) - apriori = self.extend_apriori(data, apriori, time_dim, sampling) + apriori = self.extend_apriori(data, apriori, time_dim, sampling, station_name=station_name) # calculate FIR filter coefficients - if window == "kzf": - h = firwin_kzf(*order) - else: 
     @TimeTrackingWrapper
     def clim_filter(self, data, fs, cutoff_high, order, apriori=None, sel_opts=None, sampling="1d",
                     time_dim="datetime", var_dim="variables", window: Union[str, Tuple] = "hamming",
-                    minimum_length=None, new_dim="window", plot_dates=None):
+                    minimum_length=None, new_dim="window", plot_dates=None, station_name=None,
+                    extend_length_opts: Union[dict, int] = None):

-        logging.debug(f"{data.coords['Stations'].values[0]}: extend apriori")
+        logging.debug(f"{station_name}: extend apriori")
+        extend_opts = extend_length_opts if extend_length_opts is not None else {}

         # calculate apriori information from data if not given and extend its range if not sufficient long enough
         if apriori is None:
-            apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim)
+            apriori = self.create_monthly_mean(data, time_dim, sel_opts=sel_opts, sampling=sampling)
             apriori = apriori.astype(data.dtype)
-        apriori = self.extend_apriori(data, apriori, time_dim, sampling)
+        apriori = self.extend_apriori(data, apriori, time_dim, sampling, station_name=station_name)

         # calculate FIR filter coefficients
-        if window == "kzf":
-            h = firwin_kzf(*order)
-        else:
-            h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window)
+        h = self._calculate_filter_coefficients(window, order, cutoff_high, fs)
         length = len(h)

         # use filter length if no minimum is given, otherwise use minimum + half filter length for extension
@@ -378,30 +698,28 @@ class ClimateFIRFilter:

         coll = []

         for var in reversed(data.coords[var_dim].values):
-            logging.info(f"{data.coords['Stations'].values[0]} ({var}): sel data")
+            logging.info(f"{station_name} ({var}): sel data")

-            _start = pd.to_datetime(data.coords[time_dim].min().values).year
-            _end = pd.to_datetime(data.coords[time_dim].max().values).year
+            _start, _end = self._get_year_interval(data, time_dim)
+            extend_opts_var = extend_opts.get(var, 0) if isinstance(extend_opts, dict) else extend_opts
             filt_coll = []
             for _year in range(_start, _end + 1):
-                logging.debug(f"{data.coords['Stations'].values[0]} ({var}): year={_year}")
+                logging.debug(f"{station_name} ({var}): year={_year}")

-                time_slice = self._create_time_range_extend(_year, sampling, extend_length_history)
+                # select observations and apriori data
+                time_slice = self._create_time_range_extend(
+                    _year, sampling, max(extend_length_history, extend_length_future + extend_opts_var))
                 d = data.sel({var_dim: [var], time_dim: time_slice})
                 a = apriori.sel({var_dim: [var], time_dim: time_slice})
                 if len(d.coords[time_dim]) == 0:  # no data at all for this year
                     continue

                 # combine historical data / observation [t0-length,t0] and climatological statistics [t0+1,t0+length]
-                if new_dim not in d.coords:
-                    history = self._shift_data(d, range(int(-extend_length_history), 1), time_dim, var_dim, new_dim)
-                else:
-                    history = d.sel({new_dim: slice(int(-extend_length_history), 0)})
-                if new_dim not in a.coords:
-                    future = self._shift_data(a, range(1, extend_length_future), time_dim, var_dim, new_dim)
-                else:
-                    future = a.sel({new_dim: slice(1, extend_length_future)})
-                filter_input_data = xr.concat([history.dropna(time_dim), future], dim=new_dim, join="left")
+                filter_input_data = self.combine_observation_and_apriori(d, a, time_dim, new_dim, extend_length_history,
+                                                                         extend_length_future,
+                                                                         extend_length_separator=extend_opts_var)
+
+                # select only data for current year
                 try:
                     filter_input_data = filter_input_data.sel({time_dim: str(_year)})
                 except KeyError:  # no valid data for this year
@@ -409,70 +727,45 @@ class ClimateFIRFilter:
                 if len(filter_input_data.coords[time_dim]) == 0:  # no valid data for this year
                     continue

-                logging.debug(f"{data.coords['Stations'].values[0]} ({var}): start filter convolve")
-                with TimeTracking(name=f"{data.coords['Stations'].values[0]} ({var}): filter convolve",
-                                  logging_level=logging.DEBUG):
+                # apply filter
+                logging.debug(f"{station_name} ({var}): start filter convolve")
+                with TimeTracking(name=f"{station_name} ({var}): filter convolve", logging_level=logging.DEBUG):
                     filt = xr.apply_ufunc(fir_filter_convolve, filter_input_data,
-                                          input_core_dims=[[new_dim]],
-                                          output_core_dims=[[new_dim]],
-                                          vectorize=True,
-                                          kwargs={"h": h},
-                                          output_dtypes=[d.dtype])
-
-                if minimum_length is None:
-                    filt_coll.append(filt.sel({new_dim: slice(-extend_length_history, 0)}, drop=True))
-                else:
-                    filt_coll.append(filt.sel({new_dim: slice(-minimum_length, 0)}, drop=True))
+                                          input_core_dims=[[new_dim]], output_core_dims=[[new_dim]],
+                                          vectorize=True, kwargs={"h": h}, output_dtypes=[d.dtype])
+
+                # trim data if required
+                trimmed = self._trim_data_to_minimum_length(filt, extend_length_history, new_dim, minimum_length,
+                                                            extend_length_opts=extend_opts_var)
+
filt_coll.append(trimmed) # visualization - for viz_date in set(plot_dates).intersection(filt.coords[time_dim].values): - try: - td_type = {"1d": "D", "1H": "h"}.get(sampling) - t_minus = viz_date + np.timedelta64(int(-extend_length_history), td_type) - t_plus = viz_date + np.timedelta64(int(extend_length_future), td_type) - if new_dim not in d.coords: - tmp_filter_data = self._shift_data(d.sel({time_dim: slice(t_minus, t_plus)}), - range(int(-extend_length_history), - int(extend_length_future)), - time_dim, var_dim, new_dim).sel({time_dim: viz_date}) - else: - # tmp_filter_data = d.sel({time_dim: viz_date, - # new_dim: slice(int(-extend_length_history), int(extend_length_future))}) - tmp_filter_data = None - valid_range = range(int((length + 1) / 2) if minimum_length is None else minimum_length, 1) - plot_data.append({"t0": viz_date, - "var": var, - "filter_input": filter_input_data.sel({time_dim: viz_date}), - "filter_input_nc": tmp_filter_data, - "valid_range": valid_range, - "time_range": d.sel( - {time_dim: slice(t_minus, t_plus - np.timedelta64(1, td_type))}).coords[ - time_dim].values, - "h": h, - "new_dim": new_dim}) - except: - pass + plot_data.extend(self.create_visualization(filt, d, filter_input_data, plot_dates, time_dim, new_dim, + sampling, extend_length_history, extend_length_future, + minimum_length, h, var, extend_length_opts)) # collect all filter results coll.append(xr.concat(filt_coll, time_dim)) gc.collect() - logging.debug(f"{data.coords['Stations'].values[0]}: concat all variables") - res = xr.concat(coll, var_dim) - # create result array with same shape like input data, gabs are filled by nans - logging.debug(f"{data.coords['Stations'].values[0]}: create res_full") - - new_coords = {**{k: data.coords[k].values for k in data.coords if k != new_dim}, new_dim: res.coords[new_dim]} - dims = [*data.dims, new_dim] if new_dim not in data.dims else data.dims - res = res.transpose(*dims) - # res_full = xr.DataArray(dims=dims, coords=new_coords) - # res_full.loc[res.coords] = res - # res_full.compute() - res_full = res.broadcast_like(xr.DataArray(dims=dims, coords=new_coords)) + # concat all variables + logging.debug(f"{station_name}: concat all variables") + res = xr.concat(coll, var_dim) #todo does this works with different extend_length_opts (is data trimmed or filled with nans, 2nd is target) + + # create result array with same shape like input data, gaps are filled by nans + res_full = self._create_full_filter_result_array(data, res, new_dim, station_name) return res_full, h, apriori, plot_data @staticmethod - def _create_time_range_extend(year, sampling, extend_length): + def _create_time_range_extend(year: int, sampling: str, extend_length: int) -> slice: + """ + Create a slice object for given year plus extend_length in sampling resolution. + + :param year: year to create time range for + :param sampling: sampling of time range + :param extend_length: number of time steps to extend out of given year + :returns: slice object with time range + """ td_type = {"1d": "D", "1H": "h"}.get(sampling) delta = np.timedelta64(extend_length + 1, td_type) start = np.datetime64(f"{year}-01-01") - delta @@ -480,7 +773,14 @@ class ClimateFIRFilter: return slice(start, end) @staticmethod - def _create_tmp_dimension(data): + def _create_tmp_dimension(data: xr.DataArray) -> str: + """ + Create a tmp dimension with name 'window' preferably. If name is already part of one dimensions, tmp dimension + name is multiplied by itself until not present in dims. 
Method will raise ValueError after 10 tries.
+
+        :param data: data array to create a new tmp dimension for with unique name
+        :returns: valid name for a tmp dimension (preferably 'window')
+        """
         new_dim = "window"
         count = 0
         while new_dim in data.dims:
@@ -490,33 +790,41 @@ class ClimateFIRFilter:
             raise ValueError("Could not create new dimension.")
         return new_dim

-    def _shift_data(self, data, index_value, time_dim, squeeze_dim, new_dim):
+    def _shift_data(self, data: xr.DataArray, index_value: range, time_dim: str, new_dim: str) -> xr.DataArray:
+        """
+        Shift data multiple times to create history or future along dimension new_dim for each time step.
+
+        :param data: data set to shift
+        :param index_value: range of integers to span history and/or future
+        :param time_dim: name of temporal dimension that should be shifted
+        :param new_dim: name of dimension created by data shift
+        :return: shifted data
+        """
         coll = []
         for i in index_value:
             coll.append(data.shift({time_dim: -i}))
-        new_ind = self.create_index_array(new_dim, index_value, squeeze_dim)
+        new_ind = self.create_index_array(new_dim, index_value)
         return xr.concat(coll, dim=new_ind)

     @staticmethod
-    def create_index_array(index_name: str, index_value, squeeze_dim: str):
+    def create_index_array(index_name: str, index_value: range):
+        """
+        Create index array from a range object to use as index of a data array.
+
+        :param index_name: name of the index dimension
+        :param index_value: range of values to use as indexes
+        :returns: index array for given range of values
+        """
         ind = pd.DataFrame({'val': index_value}, index=index_value)
-        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
-            dim=squeeze_dim,
-            drop=True)
+        tmp_dim = index_name + "tmp"
+        res = xr.Dataset.from_dataframe(ind).to_array(tmp_dim).rename({'index': index_name})
+        res = res.squeeze(dim=tmp_dim, drop=True)
         res.name = index_name
         return res

-    @property
-    def filter_coefficients(self):
-        return self._h
-
-    @property
-    def filtered_data(self):
-        return self._filtered
-
     @property
     def apriori_data(self):
-        return self._apriori
+        return self._apriori_list

     @property
     def initial_apriori_data(self):
@@ -767,7 +1075,8 @@ class KolmogorovZurbenkoFilterMovingWindow(KolmogorovZurbenkoBaseClass):
             raise ValueError


-def firwin_kzf(m, k):
+def firwin_kzf(m: int, k: int) -> np.array:
+    """Calculate weights of window for Kolmogorov Zurbenko filter."""
     m, k = int(m), int(k)
     coef = np.ones(m)
     for i in range(1, k):
@@ -775,10 +1084,10 @@
         for km in range(m):
             t[km, km:km + coef.size] = coef
         coef = np.sum(t, axis=0)
-    return coef / m ** k
+    return coef / (m ** k)


-def omega_null_kzf(m, k, alpha=0.5):
+def omega_null_kzf(m: int, k: int, alpha: float = 0.5) -> float:
     a = np.sqrt(6) / np.pi
     b = 1 / (2 * np.array(k))
     c = 1 - alpha ** b
@@ -786,5 +1095,6 @@
     return a * np.sqrt(c / d)


-def filter_width_kzf(m, k):
+def filter_width_kzf(m: int, k: int) -> int:
+    """Returns window width of the Kolmogorov Zurbenko filter."""
     return k * (m - 1) + 1
diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index 679f5a28fc564d56cd6f3794ee8fe8e1877b2b4c..77be3e2c8edcd43207ac40f67370067826be064b 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -4,7 +4,6 @@ __date__ = '2019-10-21'

 import inspect
 import math
-import sys

 import numpy as np
 import xarray as xr
diff --git a/mlair/helpers/testing.py b/mlair/helpers/testing.py
index 
1fb8012f50dab520df1d154303f727c36bfca418..e727d9b50308d339af79f5c5b82b592af6e91921 100644 --- a/mlair/helpers/testing.py +++ b/mlair/helpers/testing.py @@ -1,10 +1,13 @@ """Helper functions that are used to simplify testing.""" import re from typing import Union, Pattern, List +import inspect import numpy as np import xarray as xr +from mlair.helpers.helpers import remove_items, to_list + class PyTestRegex: r""" @@ -88,6 +91,20 @@ def PyTestAllEqual(check_list: List): return PyTestAllEqualClass(check_list).is_true() +def get_all_args(*args, remove=None, add=None): + res = [] + for a in args: + arg_spec = inspect.getfullargspec(a) + res.extend(arg_spec.args) + res.extend(arg_spec.kwonlyargs) + res = sorted(list(set(res))) + if remove is not None: + res = remove_items(res, remove) + if add is not None: + res += to_list(add) + return res + + def test_nested_equality(obj1, obj2): try: diff --git a/mlair/helpers/time_tracking.py b/mlair/helpers/time_tracking.py index cf366db88adc524e90c2b771bef77c71ee5a9502..5df695b9eee5352152c3189111bacf2fe05a2cb3 100644 --- a/mlair/helpers/time_tracking.py +++ b/mlair/helpers/time_tracking.py @@ -41,7 +41,10 @@ class TimeTrackingWrapper: def __get__(self, instance, cls): """Create bound method object and supply self argument to the decorated method.""" - return types.MethodType(self, instance) + if instance is None: + return self + else: + return types.MethodType(self, instance) class TimeTracking(object): diff --git a/mlair/plotting/abstract_plot_class.py b/mlair/plotting/abstract_plot_class.py index c91dbec78c4bc990cc9c40c3afb6c506b62928d8..7a91c2269ccd03608bcdbe67a634156f55fde91f 100644 --- a/mlair/plotting/abstract_plot_class.py +++ b/mlair/plotting/abstract_plot_class.py @@ -59,7 +59,7 @@ class AbstractPlotClass: if not os.path.exists(plot_folder): os.makedirs(plot_folder) self.plot_folder = plot_folder - self.plot_name = plot_name.replace("/", "_") + self.plot_name = plot_name.replace("/", "_") if plot_name is not None else plot_name self.resolution = resolution if rc_params is None: rc_params = {'axes.labelsize': 'large', @@ -71,6 +71,9 @@ class AbstractPlotClass: self.rc_params = rc_params self._update_rc_params() + def __del__(self): + plt.close('all') + def _plot(self, *args): """Abstract plot class needs to be implemented in inheritance.""" raise NotImplementedError diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py index 47051a500c29349197f3163861a0fe40cade525d..096163451355cb5011dbb2cf39c48c963d51c03c 100644 --- a/mlair/plotting/data_insight_plotting.py +++ b/mlair/plotting/data_insight_plotting.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd import xarray as xr import matplotlib +# matplotlib.use("Agg") from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates from astropy.timeseries import LombScargle @@ -21,8 +22,6 @@ from mlair.data_handler import DataCollection from mlair.helpers import TimeTrackingWrapper, to_list, remove_items from mlair.plotting.abstract_plot_class import AbstractPlotClass -matplotlib.use("Agg") - @TimeTrackingWrapper class PlotStationMap(AbstractPlotClass): # pragma: no cover @@ -497,7 +496,7 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover def _get_inputs_targets(gens, dim): k = list(gens.keys())[0] gen = gens[k][0] - inputs = to_list(gen.get_X(as_numpy=False)[0].coords[dim].values.tolist()) + inputs = list(set([y for x in to_list(gen.get_X(as_numpy=False)) for y in x.coords[dim].values.tolist()])) targets = 
to_list(gen.get_Y(as_numpy=False).coords[dim].values.tolist()) n_branches = len(gen.get_X(as_numpy=False)) return inputs, targets, n_branches @@ -518,7 +517,7 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover w = min(abs(f(gen).coords[self.window_dim].values)) data = f(gen).sel({self.window_dim: w}) res, _, g_edges = f_proc_hist(data, variables, n_bins, self.variables_dim) - for var in variables: + for var in res.keys(): b = tmp_bins.get(var, []) b.append(res[var]) tmp_bins[var] = b @@ -531,7 +530,7 @@ class PlotDataHistogram(AbstractPlotClass): # pragma: no cover bins = {} edges = {} interval_width = {} - for var in variables: + for var in tmp_bins.keys(): bin_edges = np.linspace(start[var], end[var], n_bins + 1) interval_width[var] = bin_edges[1] - bin_edges[0] for i, e in enumerate(tmp_bins[var]): @@ -868,31 +867,46 @@ def f_proc(var, d_var, f_index, time_dim="datetime", use_last_value=True): # pr def f_proc_2(g, m, pos, variables_dim, time_dim, f_index, use_last_value): # pragma: no cover + + # load lazy data + id_classes = list(filter(lambda x: "id_class" in x, dir(g))) if pos == 0 else ["id_class"] + for id_cls_name in id_classes: + id_cls = getattr(g, id_cls_name) + if hasattr(id_cls, "lazy"): + id_cls.load_lazy() if id_cls.lazy is True else None + raw_data_single = dict() - if hasattr(g.id_class, "lazy"): - g.id_class.load_lazy() if g.id_class.lazy is True else None - if m == 0: - d = g.id_class._data - if d is None: - window_dim = g.id_class.window_dim - history = g.id_class.history - last_entry = history.coords[window_dim][-1] - d1 = history.sel({window_dim: last_entry}, drop=True) - label = g.id_class.label - first_entry = label.coords[window_dim][0] - d2 = label.sel({window_dim: first_entry}, drop=True) - d = (d1, d2) - else: - gd = g.id_class - filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]} - d = (gd.input_data.sel(filter_sel), gd.target_data) - d = d[pos] if isinstance(d, tuple) else d - for var in d[variables_dim].values: - d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim) - var_str, f, pgram = f_proc(var, d_var, f_index, use_last_value=use_last_value) - raw_data_single[var_str] = [(f, pgram)] - if hasattr(g.id_class, "lazy"): - g.id_class.clean_up() if g.id_class.lazy is True else None + for dh in list(filter(lambda x: "unfiltered" not in x, id_classes)): + current_cls = getattr(g, dh) + if m == 0: + d = current_cls._data + if d is None: + window_dim = current_cls.window_dim + history = current_cls.history + last_entry = history.coords[window_dim][-1] + d1 = history.sel({window_dim: last_entry}, drop=True) + label = current_cls.label + first_entry = label.coords[window_dim][0] + d2 = label.sel({window_dim: first_entry}, drop=True) + d = (d1, d2) + else: + filter_sel = {"filter": current_cls.input_data.coords["filter"][m - 1]} + d = (current_cls.input_data.sel(filter_sel), current_cls.target_data) + d = d[pos] if isinstance(d, tuple) else d + for var in d[variables_dim].values: + d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim) + var_str, f, pgram = f_proc(var, d_var, f_index, use_last_value=use_last_value) + if var_str not in raw_data_single.keys(): + raw_data_single[var_str] = [(f, pgram)] + else: + raise KeyError(f"There are multiple pgrams for key {var_str}. 
Please check your data handler.") + + # perform clean up + for id_cls_name in id_classes: + id_cls = getattr(g, id_cls_name) + if hasattr(id_cls, "lazy"): + id_cls.clean_up() if id_cls.lazy is True else None + return raw_data_single @@ -901,13 +915,14 @@ def f_proc_hist(data, variables, n_bins, variables_dim): # pragma: no cover bin_edges = {} interval_width = {} for var in variables: - d = data.sel({variables_dim: var}).squeeze() if len(data.shape) > 1 else data - res[var], bin_edges[var] = np.histogram(d.values, n_bins) - interval_width[var] = bin_edges[var][1] - bin_edges[var][0] + if var in data.coords[variables_dim]: + d = data.sel({variables_dim: var}).squeeze() if len(data.shape) > 1 else data + res[var], bin_edges[var] = np.histogram(d.values, n_bins) + interval_width[var] = bin_edges[var][1] - bin_edges[var][0] return res, interval_width, bin_edges -class PlotClimateFirFilter(AbstractPlotClass): +class PlotClimateFirFilter(AbstractPlotClass): # pragma: no cover """ Plot climate FIR filter components. @@ -1126,3 +1141,130 @@ class PlotClimateFirFilter(AbstractPlotClass): file = os.path.join(self.plot_folder, "plot_data.pickle") with open(file, "wb") as f: dill.dump(data, f) + + +class PlotFirFilter(AbstractPlotClass): # pragma: no cover + """ + Plot FIR filter components. + + * Creates a separate folder FIR inside the given plot directory. + * For each station up to 4 examples are shown (1 for each season). + * Each filtered component and its residuum is drawn in a separate plot. + * A filter component plot includes the FIR input and the filter response + * A filter residuum plot include the FIR residuum + """ + + def __init__(self, plot_folder, plot_data, name): + + logging.info(f"start PlotFirFilter for ({name})") + + # adjust default plot parameters + rc_params = { + 'axes.labelsize': 'large', + 'xtick.labelsize': 'large', + 'ytick.labelsize': 'large', + 'legend.fontsize': 'medium', + 'axes.titlesize': 'large'} + if plot_folder is None: + return + + self.style_dict = { + "original": {"color": "darkgrey", "linestyle": "dashed", "label": "original"}, + "apriori": {"color": "darkgrey", "linestyle": "solid", "label": "estimated future"}, + "clim": {"color": "black", "linestyle": "solid", "label": "clim filter", "linewidth": 2}, + "FIR": {"color": "black", "linestyle": "dashed", "label": "ideal filter", "linewidth": 2}, + "valid_area": {"color": "whitesmoke", "label": "valid area"}, + "t0": {"color": "lightgrey", "lw": 6, "label": "$t_0$"} + } + + plot_folder = os.path.join(os.path.abspath(plot_folder), "FIR") + super().__init__(plot_folder, plot_name=None, rc_params=rc_params) + plot_dict = self._prepare_data(plot_data) + self._name = name + self._plot(plot_dict) + self._store_plot_data(plot_data) + + def _prepare_data(self, data): + """Restructure plot data.""" + plot_dict = {} + for i, o in enumerate(range(len(data))): + plot_data = data[i] + t0 = plot_data.get("t0") + filter_input = plot_data.get("filter_input") + filtered = plot_data.get("filtered") + var_dim = plot_data.get("var_dim") + time_dim = plot_data.get("time_dim") + for var in filtered.coords[var_dim].values: + plot_dict_var = plot_dict.get(var, {}) + plot_dict_t0 = plot_dict_var.get(t0, {}) + plot_dict_order = {"filter_input": filter_input.sel({var_dim: var}, drop=True), + "filtered": filtered.sel({var_dim: var}, drop=True), + "time_dim": time_dim} + plot_dict_t0[i] = plot_dict_order + plot_dict_var[t0] = plot_dict_t0 + plot_dict[var] = plot_dict_var + return plot_dict + + def _plot(self, plot_dict): + for var, 
viz_date_dict in plot_dict.items(): + for it0, t0 in enumerate(viz_date_dict.keys()): + viz_data = viz_date_dict[t0] + try: + for ifilter in sorted(viz_data.keys()): + data = viz_data[ifilter] + filter_input = data["filter_input"] + filtered = data["filtered"] + time_dim = data["time_dim"] + time_axis = filtered.coords[time_dim].values + fig, ax = plt.subplots() + + # plot backgrounds + self._plot_t0(ax, t0) + + # original data + self._plot_data(ax, time_axis, filter_input, style="original") + + # filter response + self._plot_data(ax, time_axis, filtered, style="FIR") + + # set title, legend, and save plot + ax.set_xlim((time_axis[0], time_axis[-1])) + + plt.title(f"Input of Filter ({str(var)})") + plt.legend() + fig.autofmt_xdate() + plt.tight_layout() + self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}" + self._save() + + # plot residuum + fig, ax = plt.subplots() + self._plot_t0(ax, t0) + self._plot_data(ax, time_axis, filter_input - filtered, style="FIR") + ax.set_xlim((time_axis[0], time_axis[-1])) + plt.title(f"Residuum of Filter ({str(var)})") + plt.legend(loc="upper left") + fig.autofmt_xdate() + plt.tight_layout() + + self.plot_name = f"FIR_{self._name}_{str(var)}_{it0}_{ifilter}_residuum" + self._save() + except Exception as e: + logging.info(f"Could not create plot because of:\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}") + pass + + def _plot_t0(self, ax, t0): + ax.axvline(t0, **self.style_dict["t0"]) + + def _plot_series(self, ax, time_axis, data, style): + ax.plot(time_axis, data, **self.style_dict[style]) + + def _plot_data(self, ax, time_axis, data, style="original"): + # original data + self._plot_series(ax, time_axis, data.values.flatten(), style=style) + + def _store_plot_data(self, data): + """Store plot data. 
Could be loaded in a notebook to redraw.""" + file = os.path.join(self.plot_folder, "plot_data.pickle") + with open(file, "wb") as f: + dill.dump(data, f) diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index e6d6de152e42d44f271ba986b6645d2cd36b68d0..748476b814c6c54812df27274f54615bbf08d269 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -641,20 +641,32 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover plot_name = self.plot_name for branch in self._data["branch"].unique(): self._set_title(model_name, branch) - self._plot(branch=branch) self.plot_name = f"{plot_name}_{branch}" - self._save() + try: + self._plot(branch=branch) + self._save() + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") if len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0: self.plot_name += '_separated' - self._plot(branch=branch, separate_vars=separate_vars) - self._save(bbox_inches='tight') + try: + self._plot(branch=branch, separate_vars=separate_vars) + self._save(bbox_inches='tight') + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") else: - self._plot() - self._save() + try: + self._plot() + self._save() + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") if len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0: self.plot_name += '_separated' - self._plot(separate_vars=separate_vars) - self._save(bbox_inches='tight') + try: + self._plot(separate_vars=separate_vars) + self._save(bbox_inches='tight') + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") @staticmethod def _set_bootstrap_type(boot_type): @@ -696,11 +708,26 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover number_tags = self._get_number_tag(data.coords[self._x_name].values, split_by='_') new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', keep=1, as_unique=True) - values = data.values.reshape((*data.shape[:3], len(number_tags), len(new_boot_coords))) - data = xr.DataArray(values, coords={station_dim: data.coords[station_dim], self._x_name: new_boot_coords, - "branch": number_tags, self._ahead_dim: data.coords[self._ahead_dim], - self._boot_dim: data.coords[self._boot_dim]}, - dims=[station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name]) + try: + values = data.values.reshape((*data.shape[:3], len(number_tags), len(new_boot_coords))) + data = xr.DataArray(values, coords={station_dim: data.coords[station_dim], self._x_name: new_boot_coords, + "branch": number_tags, self._ahead_dim: data.coords[self._ahead_dim], + self._boot_dim: data.coords[self._boot_dim]}, + dims=[station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name]) + except ValueError: + data_coll = [] + for nr in number_tags: + filtered_coords = list(filter(lambda x: nr in x.split("_")[0], data.coords[self._x_name].values)) + new_boot_coords = self._return_vars_without_number_tag(filtered_coords, split_by='_', keep=1, + as_unique=True) + sel_data = data.sel({self._x_name: filtered_coords}) + values = sel_data.values.reshape((*data.shape[:3], 1, len(new_boot_coords))) + sel_data = xr.DataArray(values, coords={station_dim: data.coords[station_dim], self._x_name: new_boot_coords, + "branch": [nr], self._ahead_dim: data.coords[self._ahead_dim], + 
self._boot_dim: data.coords[self._boot_dim]}, + dims=[station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name]) + data_coll.append(sel_data) + data = xr.concat(data_coll, "branch") else: try: new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', @@ -713,7 +740,7 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover if station_dim not in data.dims: data = data.expand_dims(station_dim) self._number_of_bootstraps = np.unique(data.coords[self._boot_dim].values).shape[0] - return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()) + return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()).dropna() @staticmethod def _get_target_sampling(sampling, pos): @@ -765,9 +792,10 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover def _plot_selected_variables(self, separate_vars: List, branch=None): data = self._data if branch is None else self._data[self._data["branch"] == str(branch)] - self.raise_error_if_separate_vars_do_not_exist(data, separate_vars, self._x_name) + self.raise_error_if_vars_do_not_exist(data, separate_vars, self._x_name, name="separate_vars") all_variables = self._get_unique_values_from_column_of_df(data, self._x_name) remaining_vars = helpers.remove_items(all_variables, separate_vars) + self.raise_error_if_vars_do_not_exist(data, remaining_vars, self._x_name, name="remaining_vars") data_first = self._select_data(df=data, variables=separate_vars, column_name=self._x_name) data_second = self._select_data(df=data, variables=remaining_vars, column_name=self._x_name) @@ -843,9 +871,13 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover selected_data = pd.concat([selected_data, tmp_var], axis=0) return selected_data - def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars, column_name): - if not self._variables_exist_in_df(df=data, variables=separate_vars, column_name=column_name): - raise ValueError(f"At least one entry of `separate_vars' does not exist in `self.data' ") + def raise_error_if_vars_do_not_exist(self, data, vars, column_name, name="separate_vars"): + if len(vars) == 0: + msg = f"No variables are given for `{name}' to check in `self.data' " + raise ValueError(msg) + if not self._variables_exist_in_df(df=data, variables=vars, column_name=column_name): + msg = f"At least one entry of `{name}' does not exist in `self.data' " + raise ValueError(msg) @staticmethod def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List: diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 7f2b3b59b17910ae2667e003a821fbadab755b85..5d687dc2c86cc3681dfaa6b103bd207593e969f6 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -285,13 +285,13 @@ class PostProcessing(RunEnvironment): boot_skill_score = self.calculate_feature_importance_skill_scores(bootstrap_type=boot_type, bootstrap_method=boot_method) self.feature_importance_skill_scores[boot_type][boot_method] = boot_skill_score - except (FileNotFoundError, ValueError): + except (FileNotFoundError, ValueError, OSError): if _iter != 0: - raise RuntimeError(f"calculate_feature_importance ({boot_type}, {boot_type}) was called for the " - f"2nd time. This means, that something internally goes wrong. 
Please check " - f"for possible errors") - logging.info(f"Could not load all files for feature importance ({boot_type}, {boot_type}), restart " - f"calculate_feature_importance with create_new_bootstraps=True.") + raise RuntimeError(f"calculate_feature_importance ({boot_type}, {boot_method}) was called for " + f"the 2nd time. This means, that something internally goes wrong. Please " + f"check for possible errors.") + logging.info(f"Could not load all files for feature importance ({boot_type}, {boot_method}), " + f"restart calculate_feature_importance with create_new_bootstraps=True.") self.calculate_feature_importance(True, _iter=1, bootstrap_type=boot_type, bootstrap_method=boot_method) @@ -630,6 +630,7 @@ class PostProcessing(RunEnvironment): logging.info(f"start make_prediction for {subset_type}") time_dimension = self.data_store.get("time_dim") window_dim = self.data_store.get("window_dim") + path = self.data_store.get("forecast_path") subset_type = subset.name for i, data in enumerate(subset): input_data = data.get_X() @@ -669,7 +670,6 @@ class PostProcessing(RunEnvironment): **prediction_dict) # save all forecasts locally - path = self.data_store.get("forecast_path") prefix = "forecasts_norm" if normalised is True else "forecasts" file = os.path.join(path, f"{prefix}_{str(data)}_{subset_type}.nc") all_predictions.to_netcdf(file) diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index 92882a897d012a90ea052d9491973b0be83ad3ef..ff29bd213f21616443ac825e575f0efaa17eeace 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -242,7 +242,7 @@ class PreProcessing(RunEnvironment): # start station check collection = DataCollection(name=set_name) valid_stations = [] - kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name) + kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope=set_name) use_multiprocessing = self.data_store.get("use_multiprocessing") tmp_path = self.data_store.get("tmp_path") @@ -300,7 +300,7 @@ class PreProcessing(RunEnvironment): transformation_opts = None if calculate_fresh_transformation is True else self._load_transformation() if transformation_opts is None: logging.info(f"start to calculate transformation parameters.") - kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope="train") + kwargs = self.data_store.create_args_dict(data_handler.requirements(skip_args="station"), scope="train") tmp_path = self.data_store.get_default("tmp_path", default=None) transformation_opts = data_handler.transformation(stations, tmp_path=tmp_path, **kwargs) else: diff --git a/test/test_data_handler/test_data_handler.py b/test/test_data_handler/test_abstract_data_handler.py similarity index 90% rename from test/test_data_handler/test_data_handler.py rename to test/test_data_handler/test_abstract_data_handler.py index 418c7946efe160c9bbfeccff9908a6cf17dec17f..5166717471cb9b98a53cc33462fd65e13d142b5b 100644 --- a/test/test_data_handler/test_data_handler.py +++ b/test/test_data_handler/test_abstract_data_handler.py @@ -4,11 +4,12 @@ import inspect from mlair.data_handler.abstract_data_handler import AbstractDataHandler -class TestDefaultDataHandler: +class TestAbstractDataHandler: def test_required_attributes(self): dh = AbstractDataHandler assert hasattr(dh, "_requirements") + assert hasattr(dh, "_skip_args") assert hasattr(dh, "__init__") assert hasattr(dh, "build") assert hasattr(dh, "requirements") @@ -35,8 +36,12 @@ 
diff --git a/test/test_data_handler/test_data_handler.py b/test/test_data_handler/test_abstract_data_handler.py
similarity index 90%
rename from test/test_data_handler/test_data_handler.py
rename to test/test_data_handler/test_abstract_data_handler.py
index 418c7946efe160c9bbfeccff9908a6cf17dec17f..5166717471cb9b98a53cc33462fd65e13d142b5b 100644
--- a/test/test_data_handler/test_data_handler.py
+++ b/test/test_data_handler/test_abstract_data_handler.py
@@ -4,11 +4,12 @@
 import inspect

 from mlair.data_handler.abstract_data_handler import AbstractDataHandler


-class TestDefaultDataHandler:
+class TestAbstractDataHandler:

     def test_required_attributes(self):
         dh = AbstractDataHandler
         assert hasattr(dh, "_requirements")
+        assert hasattr(dh, "_skip_args")
         assert hasattr(dh, "__init__")
         assert hasattr(dh, "build")
         assert hasattr(dh, "requirements")
@@ -35,8 +36,12 @@ class TestDefaultDataHandler:
     def test_own_args(self):
         dh = AbstractDataHandler()
         assert isinstance(dh.own_args(), list)
-        assert len(dh.own_args()) == 0
-        assert "self" not in dh.own_args()
+        assert len(dh.own_args()) == 1
+        assert "self" in dh.own_args()
+
+    def test_skip_args(self):
+        dh = AbstractDataHandler()
+        assert dh._skip_args == ["self"]

     def test_transformation(self):
         assert AbstractDataHandler.transformation() is None
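The updated expectations reflect the reworked split of responsibilities: own_args() now reports the raw collected signature including "self", while requirements() applies the class-level _skip_args list at the very end. Verified directly against the abstract base class:

from mlair.data_handler.abstract_data_handler import AbstractDataHandler

dh = AbstractDataHandler()
assert dh.own_args() == ["self"]   # raw signature, "self" is no longer stripped here
assert dh.requirements() == []     # _skip_args = ["self"] is removed at this stage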
diff --git a/test/test_data_handler/test_data_handler_mixed_sampling.py b/test/test_data_handler/test_data_handler_mixed_sampling.py
index 7418a435008f06a9016f903fe140b51d0a7c8106..0515278a8ae77880de99b0de4abf7fa85198fe49 100644
--- a/test/test_data_handler/test_data_handler_mixed_sampling.py
+++ b/test/test_data_handler/test_data_handler_mixed_sampling.py
@@ -2,13 +2,16 @@ __author__ = 'Lukas Leufen'
 __date__ = '2020-12-10'

 from mlair.data_handler.data_handler_mixed_sampling import DataHandlerMixedSampling, \
-    DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithKzFilter, \
-    DataHandlerMixedSamplingWithKzFilterSingleStation, DataHandlerSeparationOfScales, \
-    DataHandlerSeparationOfScalesSingleStation, DataHandlerMixedSamplingWithFilterSingleStation
-from mlair.data_handler.data_handler_with_filter import DataHandlerKzFilterSingleStation
+    DataHandlerMixedSamplingSingleStation, DataHandlerMixedSamplingWithFilterSingleStation, \
+    DataHandlerMixedSamplingWithFirFilterSingleStation, DataHandlerMixedSamplingWithFirFilter, \
+    DataHandlerFirFilterSingleStation, DataHandlerMixedSamplingWithClimateFirFilterSingleStation, \
+    DataHandlerMixedSamplingWithClimateFirFilter
+from mlair.data_handler.data_handler_with_filter import DataHandlerFilter, DataHandlerFilterSingleStation, \
+    DataHandlerClimateFirFilterSingleStation
 from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
-from mlair.helpers import remove_items
+from mlair.data_handler.default_data_handler import DefaultDataHandler
 from mlair.configuration.defaults import DEFAULT_INTERPOLATION_METHOD
+from mlair.helpers.testing import get_all_args

 import pytest
 import mock
@@ -25,17 +28,23 @@ class TestDataHandlerMixedSampling:
         assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingSingleStation.__qualname__

     def test_requirements(self):
+        reqs = get_all_args(DefaultDataHandler)
         obj = object.__new__(DataHandlerMixedSampling)
-        req = object.__new__(DataHandlerSingleStation)
-        assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station"))
+        assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, remove="self")
+        assert sorted(obj._requirements) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, DefaultDataHandler, remove=["self", "id_class"])
+        assert sorted(obj.requirements()) == reqs


 class TestDataHandlerMixedSamplingSingleStation:

     def test_requirements(self):
+        reqs = get_all_args(DataHandlerSingleStation)
         obj = object.__new__(DataHandlerMixedSamplingSingleStation)
-        req = object.__new__(DataHandlerSingleStation)
-        assert sorted(obj._requirements) == sorted(remove_items(req.requirements(), "station"))
+        assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, remove="self")
+        assert sorted(obj.requirements()) == reqs

     @mock.patch("mlair.data_handler.data_handler_single_station.DataHandlerSingleStation.setup_samples")
     def test_init(self, mock_super_init):
@@ -86,45 +95,97 @@ class TestDataHandlerMixedSamplingSingleStation:
     pass


-class TestDataHandlerMixedSamplingWithKzFilter:
+class TestDataHandlerMixedSamplingWithFilterSingleStation:

-    def test_data_handler(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
-        assert obj.data_handler.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__
+    def test_requirements(self):

-    def test_data_handler_transformation(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
-        assert obj.data_handler_transformation.__qualname__ == DataHandlerMixedSamplingWithKzFilterSingleStation.__qualname__
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation)
+        obj = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation)
+        assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self")
+        assert sorted(obj._requirements) == []
+        assert sorted(obj.requirements()) == reqs

-    def test_requirements(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
-        req1 = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation)
-        req2 = object.__new__(DataHandlerKzFilterSingleStation)
-        req = list(set(req1.requirements() + req2.requirements()))
-        assert sorted(obj._requirements) == sorted(remove_items(req, "station"))

+class TestDataHandlerMixedSamplingWithFirFilter:

-class TestDataHandlerMixedSamplingWithFilterSingleStation:
-    pass
+    def test_requirements(self):
+        reqs = get_all_args(DataHandlerFilter, DefaultDataHandler)
+        obj = object.__new__(DataHandlerMixedSamplingWithFirFilter)
+        assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFirFilterSingleStation,
+                            remove=["self"])
+        assert sorted(obj._requirements) == reqs
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, DataHandlerFilter,
+                            DataHandlerFirFilterSingleStation, DefaultDataHandler, remove=["self", "id_class"])
+        assert sorted(obj.requirements()) == reqs

-class TestDataHandlerSeparationOfScales:
+class TestDataHandlerMixedSamplingWithFirFilterSingleStation:

-    def test_data_handler(self):
-        obj = object.__new__(DataHandlerSeparationOfScales)
-        assert obj.data_handler.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__
+    def test_requirements(self):
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation)
+        obj = object.__new__(DataHandlerMixedSamplingWithFirFilterSingleStation)
+        assert sorted(obj.own_args()) == reqs
+        assert sorted(obj._requirements) == []
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerFirFilterSingleStation, DataHandlerSingleStation,
+                            remove="self")
+        assert sorted(obj.requirements()) == reqs

-    def test_data_handler_transformation(self):
-        obj = object.__new__(DataHandlerSeparationOfScales)
-        assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__
+
+class TestDataHandlerMixedSamplingWithClimateFirFilter:

     def test_requirements(self):
-        obj = object.__new__(DataHandlerMixedSamplingWithKzFilter)
-        req1 = object.__new__(DataHandlerMixedSamplingWithFilterSingleStation)
-        req2 = object.__new__(DataHandlerKzFilterSingleStation)
-        req = list(set(req1.requirements() + req2.requirements()))
-        assert sorted(obj._requirements) == sorted(remove_items(req, "station"))
+        reqs = get_all_args(DataHandlerMixedSamplingWithClimateFirFilter, DataHandlerFilter, DefaultDataHandler)
+        obj = object.__new__(DataHandlerMixedSamplingWithClimateFirFilter)
+        assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation,
+                            DataHandlerSingleStation, DataHandlerFirFilterSingleStation, remove=["self"])
+        assert sorted(obj._requirements) == reqs
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerClimateFirFilterSingleStation,
+                            DataHandlerSingleStation, DataHandlerFilter, DataHandlerMixedSamplingWithClimateFirFilter,
+                            DefaultDataHandler, DataHandlerFirFilterSingleStation, remove=["self", "id_class"])
+        assert sorted(obj.requirements()) == reqs

-class TestDataHandlerSeparationOfScalesSingleStation:
-    pass
+class TestDataHandlerMixedSamplingWithClimateFirFilterSingleStation:
+
+    def test_requirements(self):
+        reqs = get_all_args(DataHandlerClimateFirFilterSingleStation, DataHandlerFirFilterSingleStation,
+                            DataHandlerFilterSingleStation, DataHandlerSingleStation)
+        obj = object.__new__(DataHandlerMixedSamplingWithClimateFirFilterSingleStation)
+        assert sorted(obj.own_args()) == reqs
+        assert sorted(obj._requirements) == []
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerFirFilterSingleStation, DataHandlerSingleStation,
+                            DataHandlerClimateFirFilterSingleStation, remove="self")
+        assert sorted(obj.requirements()) == reqs
+
+
+# class TestDataHandlerSeparationOfScales:
+#
+#     def test_data_handler(self):
+#         obj = object.__new__(DataHandlerSeparationOfScales)
+#         assert obj.data_handler.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__
+#
+#     def test_data_handler_transformation(self):
+#         obj = object.__new__(DataHandlerSeparationOfScales)
+#         assert obj.data_handler_transformation.__qualname__ == DataHandlerSeparationOfScalesSingleStation.__qualname__
+#
+#     def test_requirements(self):
+#         reqs = get_all_args(DefaultDataHandler)
+#         obj = object.__new__(DataHandlerSeparationOfScales)
+#         assert sorted(obj.own_args()) == reqs
+#
+#         reqs = get_all_args(DataHandlerSeparationOfScalesSingleStation, DataHandlerKzFilterSingleStation,
+#                             DataHandlerMixedSamplingWithKzFilterSingleStation, DataHandlerFilterSingleStation,
+#                             DataHandlerSingleStation, remove=["self", "id_class"])
+#         assert sorted(obj._requirements) == reqs
+#         reqs = get_all_args(DataHandlerSeparationOfScalesSingleStation, DataHandlerKzFilterSingleStation,
+#                             DataHandlerMixedSamplingWithKzFilterSingleStation, DataHandlerFilterSingleStation,
+#                             DataHandlerSingleStation, DefaultDataHandler, remove=["self", "id_class"])
+#         assert sorted(obj.requirements()) == reqs
+
+#
+# class TestDataHandlerSeparationOfScalesSingleStation:
+#     pass
+
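All rewritten tests build their expectations with get_all_args from mlair.helpers.testing, a helper introduced alongside this patch but not shown here. Judging from its call sites, it collects the constructor arguments of every passed class, drops the names given via remove, and returns the result sorted. A rough stand-in with that inferred behavior (an assumption, not the actual implementation):

import inspect

def get_all_args(*classes, remove=None):
    # gather positional and keyword-only constructor args of every class
    args = set()
    for cls in classes:
        spec = inspect.getfullargspec(cls)
        args.update(spec.args + spec.kwonlyargs)
    remove = [remove] if isinstance(remove, str) else (remove or [])
    return sorted(args - set(remove))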
diff --git a/test/test_data_handler/test_data_handler_with_filter.py b/test/test_data_handler/test_data_handler_with_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..b83effd96ec7a496977873af0785a8406fa7114e
--- /dev/null
+++ b/test/test_data_handler/test_data_handler_with_filter.py
@@ -0,0 +1,87 @@
+import pytest
+
+from mlair.data_handler.data_handler_with_filter import DataHandlerFilter, DataHandlerFilterSingleStation, \
+    DataHandlerFirFilter, DataHandlerFirFilterSingleStation, DataHandlerClimateFirFilter, \
+    DataHandlerClimateFirFilterSingleStation
+from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
+from mlair.data_handler.default_data_handler import DefaultDataHandler
+from mlair.helpers.testing import get_all_args
+
+
+class TestDataHandlerFilter:
+
+    def test_requirements(self):
+        reqs = get_all_args(DataHandlerFilter, DefaultDataHandler)
+        obj = object.__new__(DataHandlerFilter)
+        assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, remove=["self"])
+        assert sorted(obj._requirements) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, DataHandlerFilterSingleStation, DefaultDataHandler,
+                            DataHandlerFilter, remove=["self", "id_class"])
+        assert sorted(obj.requirements()) == reqs
+
+
+class TestDataHandlerFilterSingleStation:
+
+    def test_requirements(self):
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation)
+        obj = object.__new__(DataHandlerFilterSingleStation)
+        assert sorted(obj.own_args()) == reqs
+        assert sorted(obj._requirements) == []
+        reqs = get_all_args(DataHandlerFilterSingleStation, DataHandlerSingleStation, remove="self")
+        assert sorted(obj.requirements()) == reqs
+
+
+class TestDataHandlerFirFilter:
+
+    def test_requirements(self):
+        reqs = get_all_args(DataHandlerFilter, DefaultDataHandler)
+        obj = object.__new__(DataHandlerFirFilter)
+        assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation,
+                            remove=["self"])
+        assert sorted(obj._requirements) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation,
+                            DataHandlerFilter, DefaultDataHandler, remove=["self", "id_class"])
+        assert sorted(obj.requirements()) == reqs
+
+
+class TestDataHandlerFirFilterSingleStation:
+
+    def test_requirements(self):
+
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation)
+        obj = object.__new__(DataHandlerFirFilterSingleStation)
+        assert sorted(obj.own_args()) == reqs
+        assert sorted(obj._requirements) == []
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
+                            remove="self")
+        assert sorted(obj.requirements()) == reqs
+
+
+class TestDataHandlerClimateFirFilter:
+
+    def test_requirements(self):
+        reqs = get_all_args(DataHandlerFilter, DefaultDataHandler)
+        obj = object.__new__(DataHandlerClimateFirFilter)
+        assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
+                            DataHandlerClimateFirFilterSingleStation, remove="self")
+        assert sorted(obj._requirements) == reqs
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
+                            DataHandlerClimateFirFilterSingleStation, DefaultDataHandler, DataHandlerFilter,
+                            remove=["self", "id_class"])
+        assert sorted(obj.requirements()) == reqs
+
+
+class TestDataHandlerClimateFirFilterSingleStation:
+
+    def test_requirements(self):
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
+                            DataHandlerClimateFirFilterSingleStation)
+        obj = object.__new__(DataHandlerClimateFirFilterSingleStation)
+        assert sorted(obj.own_args()) == reqs
+        assert sorted(obj._requirements) == []
+        reqs = get_all_args(DataHandlerFirFilterSingleStation, DataHandlerFilterSingleStation, DataHandlerSingleStation,
+                            DataHandlerClimateFirFilterSingleStation, remove="self")
+        assert sorted(obj.requirements()) == reqs
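Note the pattern used throughout these tests: handlers are created with object.__new__(cls), which allocates an instance without running __init__, so no station data is loaded while the classmethod-based signature introspection still works. The trick in isolation (Expensive is a made-up class for illustration):

import inspect

class Expensive:
    def __init__(self, station, window_history_size=13):
        raise RuntimeError("must not run during signature tests")

obj = object.__new__(Expensive)  # no __init__ call, no exception
assert "station" in inspect.getfullargspec(Expensive).args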
diff --git a/test/test_data_handler/test_default_data_handler.py b/test/test_data_handler/test_default_data_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e0a5db3d82bf528bfeef321799841588e2d5678
--- /dev/null
+++ b/test/test_data_handler/test_default_data_handler.py
@@ -0,0 +1,23 @@
+import pytest
+from mlair.data_handler.default_data_handler import DefaultDataHandler
+from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
+from mlair.helpers.testing import get_all_args
+
+
+class TestDefaultDataHandler:
+
+    def test_requirements(self):
+        reqs = get_all_args(DefaultDataHandler)
+        obj = object.__new__(DefaultDataHandler)
+        assert sorted(obj.own_args()) == reqs
+        reqs = get_all_args(DataHandlerSingleStation, remove="self")
+        assert sorted(obj._requirements) == reqs
+        reqs = get_all_args(DefaultDataHandler, DataHandlerSingleStation, remove=["self", "id_class"])
+        assert sorted(obj.requirements()) == reqs
+        reqs = get_all_args(DefaultDataHandler, DataHandlerSingleStation, remove=["self", "id_class", "station"])
+        assert sorted(obj.requirements(skip_args="station")) == reqs
+
+
+
+
+
diff --git a/test/test_data_handler/test_default_data_handler_single_station.py b/test/test_data_handler/test_default_data_handler_single_station.py
new file mode 100644
index 0000000000000000000000000000000000000000..fea8f9cbddea4cdac350bc9df2c60c8e3a2e7399
--- /dev/null
+++ b/test/test_data_handler/test_default_data_handler_single_station.py
@@ -0,0 +1,15 @@
+import pytest
+from mlair.data_handler.default_data_handler import DefaultDataHandler
+from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation
+from mlair.helpers.testing import get_all_args
+from mlair.helpers import remove_items
+
+
+class TestDataHandlerSingleStation:
+
+    def test_requirements(self):
+        reqs = get_all_args(DataHandlerSingleStation)
+        obj = object.__new__(DataHandlerSingleStation)
+        assert sorted(obj.own_args()) == reqs
+        assert obj._requirements == []
+        assert sorted(obj.requirements()) == remove_items(reqs, "self")
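remove_items from mlair.helpers is used here to express the expectation that requirements() equals the full signature minus "self". As used throughout these tests, it returns the list without the given element or elements, preserving the order of the rest:

from mlair.helpers import remove_items

assert remove_items(["self", "station", "sampling"], "self") == ["station", "sampling"]
assert remove_items(["self", "station"], ["station", "self"]) == []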
diff --git a/test/test_helpers/test_filter.py b/test/test_helpers/test_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..519a36b3438cafd041cc65c43572fc026eced4dd
--- /dev/null
+++ b/test/test_helpers/test_filter.py
@@ -0,0 +1,403 @@
+__author__ = 'Lukas Leufen'
+__date__ = '2021-11-18'
+
+import pytest
+import inspect
+import numpy as np
+import xarray as xr
+import pandas as pd
+
+from mlair.helpers.filter import ClimateFIRFilter, filter_width_kzf, firwin_kzf, omega_null_kzf, fir_filter_convolve
+
+
+class TestClimateFIRFilter:
+
+    @pytest.fixture
+    def var_dim(self):
+        return "variables"
+
+    @pytest.fixture
+    def time_dim(self):
+        return "datetime"
+
+    @pytest.fixture
+    def data(self):
+        pos = np.linspace(0, 4, num=100)
+        return np.cos(pos * np.pi)
+
+    @pytest.fixture
+    def xr_array(self, data, time_dim):
+        start = np.datetime64("2010-01-01 00:00")
+        time_index = [start + np.timedelta64(h, "h") for h in range(len(data))]
+        array = xr.DataArray(data.reshape(len(data), 1), dims=[time_dim, "station"],
+                             coords={time_dim: time_index, "station": ["DE266X"]})
+        return array
+
+    @pytest.fixture
+    def xr_array_long(self, data, time_dim):
+        start = np.datetime64("2010-01-01 00:00")
+        time_index = [start + np.timedelta64(175 * h, "h") for h in range(len(data))]
+        array = xr.DataArray(data.reshape(len(data), 1), dims=[time_dim, "station"],
+                             coords={time_dim: time_index, "station": ["DE266X"]})
+        return array
+
+    @pytest.fixture
+    def xr_array_long_with_var(self, data, time_dim, var_dim):
+        start = np.datetime64("2010-01-01 00:00")
+        time_index = [start + np.timedelta64(175 * h, "h") for h in range(len(data))]
+        array = xr.DataArray(data.reshape(*data.shape, 1), dims=[time_dim, "station"],
+                             coords={time_dim: time_index, "station": ["DE266X"]})
+        array = array.resample({time_dim: "1H"}).interpolate()
+        new_data = xr.concat([array,
+                              array + np.sin(np.arange(array.shape[0]) * 2 * np.pi / 24).reshape(*array.shape),
+                              array + np.random.random(size=array.shape),
+                              array * np.random.random(size=array.shape)],
+                             dim=pd.Index(["o3", "temp", "wind", "sun"], name=var_dim))
+        return new_data
+
+    def test_combine_observation_and_apriori_no_new_dim(self, xr_array, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        apriori = xr.ones_like(xr_array)
+        res = obj.combine_observation_and_apriori(xr_array, apriori, time_dim, "window", 20, 10)
+        assert res.coords[time_dim].values[0] == xr_array.coords[time_dim].values[20]
+        first_entry = res.sel({time_dim: res.coords[time_dim].values[0]})
+        assert np.testing.assert_array_equal(first_entry.sel(window=range(-20, 1)).values, xr_array.values[:21]) is None
+        assert np.testing.assert_array_equal(first_entry.sel(window=range(1, 10)).values, apriori.values[21:30]) is None
+
+    def test_combine_observation_and_apriori_with_new_dim(self, xr_array, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        apriori = xr.ones_like(xr_array)
+        xr_array = obj._shift_data(xr_array, range(-20, 1), time_dim, new_dim="window")
+        apriori = obj._shift_data(apriori, range(1, 10), time_dim, new_dim="window")
+        res = obj.combine_observation_and_apriori(xr_array, apriori, time_dim, "window", 10, 10)
+        assert res.coords[time_dim].values[0] == xr_array.coords[time_dim].values[10]
+        date_pos = res.coords[time_dim].values[0]
+        first_entry = res.sel({time_dim: date_pos})
+        assert xr.testing.assert_equal(first_entry.sel(window=range(-10, 1)),
+                                       xr_array.sel({time_dim: date_pos, "window": range(-10, 1)})) is None
+        assert xr.testing.assert_equal(first_entry.sel(window=range(1, 10)), apriori.sel({time_dim: date_pos})) is None
+
+    def test_shift_data(self, xr_array, time_dim):
+        remaining_dims = set(xr_array.dims).difference([time_dim])
+        obj = object.__new__(ClimateFIRFilter)
+        index_values = range(-15, 1)
+        res = obj._shift_data(xr_array, index_values, time_dim, new_dim="window")
+        assert len(res.dims) == len(remaining_dims) + 2
+        assert len(set(res.dims).difference([time_dim, "window", *remaining_dims])) == 0
+        assert np.testing.assert_array_equal(res.coords["window"].values, np.arange(-15, 1)) is None
+        sel = res.sel({time_dim: res.coords[time_dim].values[15]})
+        assert sel.sel(window=-15).values == xr_array.sel({time_dim: xr_array.coords[time_dim].values[0]}).values
+        assert sel.sel(window=0).values == xr_array.sel({time_dim: xr_array.coords[time_dim].values[15]}).values
+
+    def test_create_index_array(self):
+        obj = object.__new__(ClimateFIRFilter)
+        index_name = "test_index_name"
+        index_values = range(-10, 1)
+        res = obj.create_index_array(index_name, index_values)
+        assert len(res.dims) == 1
+        assert res.dims[0] == index_name
+        assert res.shape == (11,)
+        assert np.testing.assert_array_equal(res.values, np.arange(-10, 1)) is None
+
+    def test_create_tmp_dimension(self, xr_array, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        res = obj._create_tmp_dimension(xr_array)
+        assert res == "window"
+        xr_array = xr_array.rename({time_dim: "window"})
+        res = obj._create_tmp_dimension(xr_array)
+        assert res == "windowwindow"
+        xr_array = xr_array.rename({"window": "windowwindow"})
+        res = obj._create_tmp_dimension(xr_array)
+        assert res == "window"
+
+    def test_create_tmp_dimension_iter_limit(self, xr_array, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        dim_name = "window"
+        xr_array = xr_array.rename({time_dim: "window"})
+        for i in range(11):
+            dim_name += dim_name
+            xr_array = xr_array.expand_dims(dim_name, -1)
+        with pytest.raises(ValueError) as e:
+            obj._create_tmp_dimension(xr_array)
+        assert "Could not create new dimension." in e.value.args[0]
+
+    def test_minimum_length(self):
+        obj = object.__new__(ClimateFIRFilter)
+        res = obj._minimum_length([43], 15, 0, "hamming")
+        assert res == 15
+        res = obj._minimum_length([43, 13], 15, 0, ("kaiser", 10))
+        assert res == 28
+        res = obj._minimum_length([43, 13], 15, 1, "hamming")
+        assert res == 15
+        res = obj._minimum_length([128, 64, 43], None, 0, "hamming")
+        assert res == 64
+        res = obj._minimum_length([43], None, 0, "hamming")
+        assert res is None
+
+    def test_minimum_length_with_kzf(self):
+        obj = object.__new__(ClimateFIRFilter)
+        res = obj._minimum_length([(15, 5), (5, 3)], None, 0, "kzf")
+        assert res == 13
+
+    def test_calculate_filter_coefficients(self):
+        obj = object.__new__(ClimateFIRFilter)
+        res = obj._calculate_filter_coefficients("hamming", 20, 1, 24)
+        assert res.shape == (20,)
+        assert np.testing.assert_almost_equal(res.sum(), 1) is None
+        res = obj._calculate_filter_coefficients(("kaiser", 10), 20, 1, 24)
+        assert res.shape == (20,)
+        assert np.testing.assert_almost_equal(res.sum(), 1) is None
+        res = obj._calculate_filter_coefficients("kzf", (5, 5), 1, 24)
+        assert res.shape == (21,)
+        assert np.testing.assert_almost_equal(res.sum(), 1) is None
+
+    def test_create_monthly_mean(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        res = obj.create_monthly_mean(xr_array_long, time_dim)
+        assert res.shape == (1462, 1)
+        assert np.datetime64("2008-12-16") == res.coords[time_dim][0].values
+        assert np.datetime64("2012-12-16") == res.coords[time_dim][-1].values
+        mean_jan = xr_array_long[xr_array_long[f"{time_dim}.month"] == 1].mean()
+        assert res.sel({time_dim: "2009-01-16"}) == mean_jan
+        mean_jul = xr_array_long[xr_array_long[f"{time_dim}.month"] == 7].mean()
+        assert res.sel({time_dim: "2009-07-16"}) == mean_jul
+        assert res.sel({time_dim: "2010-06-15"}) < res.sel({time_dim: "2010-06-16"})
+        assert res.sel({time_dim: "2010-06-17"}) > res.sel({time_dim: "2010-06-16"})
+
+    def test_create_monthly_mean_sampling(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        res = obj.create_monthly_mean(xr_array_long, time_dim, sampling="1m")
+        assert res.shape == (49, 1)
+        res = obj.create_monthly_mean(xr_array_long, time_dim, sampling="1H")
+        assert res.shape == (35065, 1)
+        mean_jun = xr_array_long[xr_array_long[f"{time_dim}.month"] == 6].mean()
+        assert res.sel({time_dim: "2010-06-15T00:00:00"}) == mean_jun
+        assert res.sel({time_dim: "2011-06-15T00:00:00"}) == mean_jun
+
+    def test_create_monthly_mean_sel_opts(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        sel_opts = {time_dim: slice("2010-05", "2010-08")}
+        res = obj.create_monthly_mean(xr_array_long, time_dim, sel_opts=sel_opts)
+        assert res.dropna(time_dim)[f"{time_dim}.month"].min() == 5
+        assert res.dropna(time_dim)[f"{time_dim}.month"].max() == 8
+        mean_jun_2010 = xr_array_long[xr_array_long[f"{time_dim}.month"] == 6].sel({time_dim: "2010"}).mean()
+        assert res.sel({time_dim: "2010-06-15T00:00:00"}) == mean_jun_2010
+
+    def test_compute_hourly_mean_per_month(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate()
+        res = obj._compute_hourly_mean_per_month(xr_array_long, time_dim, True)
+        assert len(res.keys()) == 12
+        assert 6 in res.keys()
+        assert np.testing.assert_almost_equal(res[12].mean(), 0) is None
+        assert np.testing.assert_almost_equal(res[3].mean(), 0) is None
+        assert res[8].shape == (24, 1)
+
+    def test_compute_hourly_mean_per_month_no_anomaly(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate()
+        res = obj._compute_hourly_mean_per_month(xr_array_long, time_dim, False)
+        assert len(res.keys()) == 12
+        assert 9 in res.keys()
+        assert np.testing.assert_array_less(res[2], res[1]) is None
+
+    def test_create_seasonal_cycle_of_hourly_mean(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate()
+        monthly = obj.create_monthly_unity_array(xr_array_long, time_dim) * np.nan
+        seasonal_hourly_means = obj._compute_hourly_mean_per_month(xr_array_long, time_dim, True)
+        res = obj._create_seasonal_cycle_of_single_hour_mean(monthly, seasonal_hourly_means, 0, time_dim, "1h")
+        assert res[f"{time_dim}.hour"].sum() == 0
+        assert np.testing.assert_almost_equal(res.sel({time_dim: "2010-12-01"}), res.sel({time_dim: "2011-12-01"})) is None
+        res = obj._create_seasonal_cycle_of_single_hour_mean(monthly, seasonal_hourly_means, 13, time_dim, "1h")
+        assert res[f"{time_dim}.hour"].mean() == 13
+        assert np.testing.assert_almost_equal(res.sel({time_dim: "2010-12-01"}), res.sel({time_dim: "2011-12-01"})) is None
+
+    def test_create_seasonal_hourly_mean(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate()
+        res = obj.create_seasonal_hourly_mean(xr_array_long, time_dim)
+        assert len(set(res.dims).difference(xr_array_long.dims)) == 0
+        assert res.coords[time_dim][0] < xr_array_long.coords[time_dim][0]
+        assert res.coords[time_dim][-1] > xr_array_long.coords[time_dim][-1]
+
+    def test_create_seasonal_hourly_mean_sel_opts(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        xr_array_long = xr_array_long.resample({time_dim: "1H"}).interpolate()
+        sel_opts = {time_dim: slice("2010-05", "2010-08")}
+        res = obj.create_seasonal_hourly_mean(xr_array_long, time_dim, sel_opts=sel_opts)
+        assert res.dropna(time_dim)[f"{time_dim}.month"].min() == 5
+        assert res.dropna(time_dim)[f"{time_dim}.month"].max() == 8
+
+    def test_create_unity_array(self, xr_array, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        res = obj.create_monthly_unity_array(xr_array, time_dim)
+        assert np.datetime64("2008-12-16") == res.coords[time_dim][0].values
+        assert np.datetime64("2011-01-16") == res.coords[time_dim][-1].values
+        assert res.max() == res.min()
+        assert res.max() == 1
+        assert res.shape == (26, 1)
+        res = obj.create_monthly_unity_array(xr_array, time_dim, extend_range=0)
+        assert res.shape == (1, 1)
+        assert np.datetime64("2010-01-16") == res.coords[time_dim][0].values
+        res = obj.create_monthly_unity_array(xr_array, time_dim, extend_range=28)
+        assert res.shape == (3, 1)
+
+    def test_extend_apriori_at_end(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        apriori = xr.ones_like(xr_array_long).sel({time_dim: "2010"})
+        res = obj.extend_apriori(xr_array_long, apriori, time_dim)
+        assert res.coords[time_dim][0] == apriori.coords[time_dim][0]
+        assert (res.coords[time_dim][-1] - xr_array_long.coords[time_dim][-1]) / np.timedelta64(1, "D") >= 365
+        apriori = xr.ones_like(xr_array_long).sel({time_dim: slice("2010", "2011-08")})
+        res = obj.extend_apriori(xr_array_long, apriori, time_dim)
+        assert (res.coords[time_dim][-1] - xr_array_long.coords[time_dim][-1]) / np.timedelta64(1, "D") >= (1.5 * 365)
+
+    def test_extend_apriori_at_start(self, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        apriori = xr.ones_like(xr_array_long).sel({time_dim: "2011"})
+        res = obj.extend_apriori(xr_array_long.sel({time_dim: slice("2010", "2010-10")}), apriori, time_dim)
+        assert (res.coords[time_dim][0] - apriori.coords[time_dim][0]) / np.timedelta64(1, "D") <= -365 * 2
+        assert res.coords[time_dim][-1] == apriori.coords[time_dim][-1]
+        apriori = xr.ones_like(xr_array_long).sel({time_dim: slice("2010-02", "2011")})
+        res = obj.extend_apriori(xr_array_long, apriori, time_dim)
+        assert (res.coords[time_dim][0] - apriori.coords[time_dim][0]) / np.timedelta64(1, "D") <= -365
+
+    def test_get_year_interval(self, xr_array, xr_array_long, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        assert obj._get_year_interval(xr_array, time_dim) == (2010, 2010)
+        assert obj._get_year_interval(xr_array_long, time_dim) == (2010, 2011)
+
+    def test_create_time_range_extend(self):
+        obj = object.__new__(ClimateFIRFilter)
+        res = obj._create_time_range_extend(1992, "1d", 10)
+        assert isinstance(res, slice)
+        assert res.start == np.datetime64("1991-12-21")
+        assert res.stop == np.datetime64("1993-01-11")
+        assert res.step is None
+        res = obj._create_time_range_extend(1992, "1H", 24)
+        assert isinstance(res, slice)
+        assert res.start == np.datetime64("1991-12-30T23:00:00")
+        assert res.stop == np.datetime64("1993-01-01T01:00:00")
+        assert res.step is None
+
+    def test_properties(self):
+        obj = object.__new__(ClimateFIRFilter)
+        obj._h = [1, 2, 3]
+        obj._filtered = [4, 5, 63]
+        obj._apriori_list = [10, 11, 12, 13]
+        assert obj.filter_coefficients == [1, 2, 3]
+        assert obj.filtered_data == [4, 5, 63]
+        assert obj.apriori_data == [10, 11, 12, 13]
+        assert obj.initial_apriori_data == 10
+
+    def test_trim_data_to_minimum_length(self, xr_array, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        xr_array = obj._shift_data(xr_array, range(-20, 1), time_dim, new_dim="window")
+        res = obj._trim_data_to_minimum_length(xr_array, 5, "window")
+        assert xr_array.shape == (21, 100, 1)
+        assert res.shape == (6, 100, 1)
+        res = obj._trim_data_to_minimum_length(xr_array, 5, "window", 10)
+        assert res.shape == (11, 100, 1)
+        res = obj._trim_data_to_minimum_length(xr_array, 30, "window")
+        assert res.shape == (21, 100, 1)
+
+    def test_create_full_filter_result_array(self, xr_array, time_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        xr_array_window = obj._shift_data(xr_array, range(-10, 1), time_dim, new_dim="window").dropna(time_dim)
+        res = obj._create_full_filter_result_array(xr_array, xr_array_window, "window")
+        assert res.dims == (*xr_array.dims, "window")
+        assert res.shape == (*xr_array.shape, 11)
+        res2 = obj._create_full_filter_result_array(res, xr_array_window, "window")
+        assert res.dims == res2.dims
+        assert res.shape == res2.shape
+
+    def test_clim_filter(self, xr_array_long_with_var, time_dim, var_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        filter_order = 10*24+1
+        res = obj.clim_filter(xr_array_long_with_var, 24, 0.05, 10*24+1, sampling="1H", time_dim=time_dim, var_dim=var_dim)
+        assert len(res) == 4
+
+        # check filter data properties
+        assert res[0].shape == (*xr_array_long_with_var.shape, filter_order + 1)
+        assert res[0].dims == (*xr_array_long_with_var.dims, "window")
+
+        # check filter properties
+        assert np.testing.assert_almost_equal(
+            res[1], obj._calculate_filter_coefficients("hamming", filter_order, 0.05, 24)) is None
+
+        # check apriori
+        apriori = obj.create_monthly_mean(xr_array_long_with_var, time_dim, sampling="1H")
+        apriori = apriori.astype(xr_array_long_with_var.dtype)
+        apriori = obj.extend_apriori(xr_array_long_with_var, apriori, time_dim, "1H")
+        assert xr.testing.assert_equal(res[2], apriori) is None
+
+        # check plot data format
+        assert isinstance(res[3], list)
+        assert isinstance(res[3][0], dict)
+        keys = {"t0", "var", "filter_input", "filter_input_nc", "valid_range", "time_range", "h", "new_dim"}
+        assert len(keys.symmetric_difference(res[3][0].keys())) == 0
+
+    def test_clim_filter_kwargs(self, xr_array_long_with_var, time_dim, var_dim):
+        obj = object.__new__(ClimateFIRFilter)
+        filter_order = 10 * 24 + 1
+        apriori = obj.create_seasonal_hourly_mean(xr_array_long_with_var, time_dim, sampling="1H", as_anomaly=False)
+        apriori = apriori.astype(xr_array_long_with_var.dtype)
+        apriori = obj.extend_apriori(xr_array_long_with_var, apriori, time_dim, "1H")
+        plot_dates = [xr_array_long_with_var.coords[time_dim][1800].values]
+        res = obj.clim_filter(xr_array_long_with_var, 24, 0.05, 10 * 24 + 1, sampling="1H", time_dim=time_dim,
+                              var_dim=var_dim, new_dim="total_new_dim", window=("kaiser", 5), minimum_length=1000,
+                              apriori=apriori, plot_dates=plot_dates)
+
+        assert res[0].shape == (*xr_array_long_with_var.shape, 1000 + 1)
+        assert res[0].dims == (*xr_array_long_with_var.dims, "total_new_dim")
+        assert np.testing.assert_almost_equal(
+            res[1], obj._calculate_filter_coefficients(("kaiser", 5), filter_order, 0.05, 24)) is None
+        assert xr.testing.assert_equal(res[2], apriori) is None
+        assert len(res[3]) == len(res[0].coords[var_dim])
+
+
+class TestFirFilterConvolve:
+
+    def test_fir_filter_convolve(self):
+        data = np.cos(np.linspace(0, 4, num=100) * np.pi)
+        obj = object.__new__(ClimateFIRFilter)
+        h = obj._calculate_filter_coefficients("hamming", 21, 0.25, 1)
+        res = fir_filter_convolve(data, h)
+        assert res.shape == (100,)
+        assert np.testing.assert_almost_equal(np.dot(data[40:61], h) / sum(h), res[50]) is None
+
+
+class TestFirwinKzf:
+
+    def test_firwin_kzf(self):
+        res = firwin_kzf(3, 3)
+        assert np.testing.assert_almost_equal(res.sum(), 1) is None
+        assert res.shape == (7,)
+        assert np.testing.assert_array_equal(res * (3**3), np.array([1, 3, 6, 7, 6, 3, 1])) is None
+
+
+class TestFilterWidthKzf:
+
+    def test_filter_width_kzf(self):
+        assert filter_width_kzf(15, 5) == 71
+        assert filter_width_kzf(3, 5) == 11
+
+
+class TestOmegaNullKzf:
+
+    def test_omega_null_kzf(self):
+        assert np.testing.assert_almost_equal(omega_null_kzf(13, 3), 0.01986, decimal=5) is None
+        assert np.testing.assert_almost_equal(omega_null_kzf(105, 5), 0.00192, decimal=5) is None
+        assert np.testing.assert_almost_equal(omega_null_kzf(3, 5), 0.07103, decimal=5) is None
+
+    def test_omega_null_kzf_alpha(self):
+        assert np.testing.assert_almost_equal(omega_null_kzf(3, 3, alpha=1), 0, decimal=1) is None
+        assert np.testing.assert_almost_equal(omega_null_kzf(3, 3, alpha=0), 0.25989, decimal=5) is None
+        assert np.testing.assert_almost_equal(omega_null_kzf(3, 3), omega_null_kzf(3, 3, alpha=0.5), decimal=5) is None
+
+
+
+
+
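The constants asserted in TestFirwinKzf and TestFilterWidthKzf follow from the definition of the Kolmogorov-Zurbenko (KZ) filter: k passes of a length-m moving average collapse into a single FIR filter whose coefficients are the k-fold self-convolution of a boxcar, with total width k * (m - 1) + 1. The expected values can be reproduced with plain numpy (kz_coefficients is an illustrative re-derivation, not the firwin_kzf implementation):

import numpy as np

def kz_coefficients(m, k):
    # k-fold convolution of a length-m boxcar, normalised to sum to 1
    h = np.ones(m)
    for _ in range(k - 1):
        h = np.convolve(h, np.ones(m))
    return h / m ** k

h = kz_coefficients(3, 3)
assert h.shape == (7,)                                   # width = 3 * (3 - 1) + 1
np.testing.assert_array_equal(h * 3 ** 3, [1, 3, 6, 7, 6, 3, 1])
assert 5 * (15 - 1) + 1 == 71                            # cf. filter_width_kzf(15, 5)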
diff --git a/test/test_helpers/test_testing_helpers.py b/test/test_helpers/test_testing_helpers.py
index 83ba0101cd452869af8c56f44432e697d290fa97..bceed646c345d3add4602e67b55da1553eabdbaa 100644
--- a/test/test_helpers/test_testing_helpers.py
+++ b/test/test_helpers/test_testing_helpers.py
@@ -11,7 +11,8 @@ class TestPyTestRegex:

     def test_init(self):
         test_regex = PyTestRegex(r"TestString\d+")
-        assert isinstance(test_regex._regex, re._pattern_type)
+        pattern = re._pattern_type if hasattr(re, "_pattern_type") else re.Pattern
+        assert isinstance(test_regex._regex, pattern)

     def test_eq(self):
         assert PyTestRegex(r"TestString\d*") == "TestString"
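This final hunk keeps the test running on both old and new interpreters: the private re._pattern_type was dropped around Python 3.7, which exposed the public re.Pattern type instead. The fallback can be checked directly:

import re

# re._pattern_type (older Python) vs. the public re.Pattern (3.7+)
pattern_type = re._pattern_type if hasattr(re, "_pattern_type") else re.Pattern
assert isinstance(re.compile(r"TestString\d+"), pattern_type)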