diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py index c020a4134f79fe8a7446f45a791eda9057dc6885..36d6e9ae5394705af4b9fbcfd1d8ff77572642b5 100644 --- a/mlair/data_handler/abstract_data_handler.py +++ b/mlair/data_handler/abstract_data_handler.py @@ -34,7 +34,7 @@ class AbstractDataHandler: return remove_items(list_of_args, ["self"] + list(args)) @classmethod - def store_attributes(cls): + def store_attributes(cls) -> list: """ Let MLAir know that some data should be stored in the data store. This is used for calculations on the train subset that should be applied to validation and test subset. diff --git a/mlair/data_handler/data_handler_with_filter.py b/mlair/data_handler/data_handler_with_filter.py index b9b90d440074a1ee16db84bd7269326e4981957f..0619c74abc6b59e471a318406dec094486cf0966 100644 --- a/mlair/data_handler/data_handler_with_filter.py +++ b/mlair/data_handler/data_handler_with_filter.py @@ -14,7 +14,7 @@ from mlair.data_handler.data_handler_single_station import DataHandlerSingleStat from mlair.data_handler import DefaultDataHandler from mlair.helpers import remove_items, to_list, TimeTrackingWrapper from mlair.helpers.filter import KolmogorovZurbenkoFilterMovingWindow as KZFilter -from mlair.helpers.filter import FIRFilter +from mlair.helpers.filter import FIRFilter, ClimateFIRFilter # define a more general date type for type hinting str_or_list = Union[str, List[str]] @@ -67,7 +67,8 @@ class DataHandlerFilterSingleStation(DataHandlerSingleStation): def make_input_target(self): data, self.meta = self.load_data(self.path, self.station, self.statistics_per_var, self.sampling, - self.station_type, self.network, self.store_data_locally, self.data_origin) + self.station_type, self.network, self.store_data_locally, self.data_origin, + self.start, self.end) self._data = self.interpolate(data, dim=self.time_dim, method=self.interpolation_method, limit=self.interpolation_limit) self.set_inputs_and_targets() @@ -277,3 +278,123 @@ class DataHandlerKzFilter(DefaultDataHandler): data_handler = DataHandlerKzFilterSingleStation data_handler_transformation = DataHandlerKzFilterSingleStation _requirements = data_handler.requirements() + + +class DataHandlerClimateFirFilterSingleStation(DataHandlerFirFilterSingleStation): + """ + Data handler for a single station to be used by a superior data handler. Inputs are FIR filtered. In contrast to + the simple DataHandlerFirFilterSingleStation, this data handler is centered around t0 to have no time delay. For + values in the future (t > t0), this data handler assumes a climatological value for the low pass data and values of + 0 for all residuum components. + + :param apriori: Data to use as apriori information. This should be either a xarray dataarray containing monthly or + any other heuristic to support the clim filter, or a list of such arrays containint heuristics for all residua + in addition. The 2nd can be used together with apriori_type `residuum_stat` which estimates the error of the + residuum when the clim filter should be applied with exogenous parameters. If apriori_type is None/`zeros` data + can be provided, but this is not required in this case. + :param apriori_type: set type of information that is provided to the clim filter. For the first low pass always a + calculated or given statistic is used. For residuum prediction a constant value of zero is assumed if + apriori_type is None or `zeros`, and a climatology of the residuum is used for `residuum_stat`. + """ + + _requirements = remove_items(DataHandlerFirFilterSingleStation.requirements(), "station") + _hash = DataHandlerFirFilterSingleStation._hash + ["apriori_type", "apriori_sel_opts"] + _store_attributes = DataHandlerFirFilterSingleStation.store_attributes() + ["apriori"] + + def __init__(self, *args, apriori=None, apriori_type=None, apriori_sel_opts=None, **kwargs): + self.apriori_type = apriori_type + self.climate_filter_coeff = None # coefficents of the used FIR filter + self.apriori = apriori # exogenous apriori information or None to calculate from data (endogenous) + self.all_apriori = None # collection of all apriori information + self.apriori_sel_opts = apriori_sel_opts # ensure to separate exogenous and endogenous information + super().__init__(*args, **kwargs) + + @TimeTrackingWrapper + def apply_filter(self): + """Apply FIR filter only on inputs.""" + apriori = self.apriori.get(str(self)) if isinstance(self.apriori, dict) else self.apriori + climate_filter = ClimateFIRFilter(self.input_data, self.fs, self.filter_order, self.filter_cutoff_freq, + self.filter_window_type, time_dim=self.time_dim, var_dim=self.target_dim, + apriori_type=self.apriori_type, apriori=apriori, + sel_opts=self.apriori_sel_opts) + self.climate_filter_coeff = climate_filter.filter_coefficients + + # store apriori information: store all if residuum_stat method was used, otherwise just store initial apriori + if self.apriori_type == "residuum_stat": + self.apriori = climate_filter.apriori_data + else: + self.apriori = climate_filter.initial_apriori_data + self.all_apriori = climate_filter.apriori_data + climate_filter_data = climate_filter.filtered_data + + # add unfiltered raw data + if self._add_unfiltered is True: + climate_filter_data.append(self.input_data) + + # create input data with filter index + self.input_data = xr.concat(climate_filter_data, pd.Index(self.create_filter_index(), name=self.filter_dim)) + + # this is just a code snippet to check the results of the filter + # import matplotlib + # matplotlib.use("TkAgg") + # import matplotlib.pyplot as plt + # self.input_data.sel(filter="low", variables="temp", Stations="DEBW107").plot() + # self.input_data.sel(variables="temp", Stations="DEBW107").plot.line(hue="filter") + + def create_filter_index(self) -> pd.Index: + """ + Round cut off periods in days and append 'res' for residuum index. + + Round small numbers (<10) to single decimal, and higher numbers to int. Transform as list of str and append + 'res' for residuum index. Add index unfiltered if the raw / unfiltered data is appended to data in addition. + """ + index = np.round(self.filter_cutoff_period, 1) + f = lambda x: int(np.round(x)) if x >= 10 else np.round(x, 1) + index = list(map(f, index.tolist())) + index = list(map(lambda x: str(x) + "d", index)) + ["res"] + if self._add_unfiltered: + index.append("unfiltered") + return pd.Index(index, name=self.filter_dim) + + def _create_lazy_data(self): + return [self._data, self.meta, self.input_data, self.target_data, self.climate_filter_coeff, + self.apriori, self.all_apriori] + + def _extract_lazy(self, lazy_data): + _data, _meta, _input_data, _target_data, self.climate_filter_coeff, self.apriori, self.all_apriori = lazy_data + DataHandlerSingleStation._extract_lazy(self, (_data, _meta, _input_data, _target_data)) + + @staticmethod + def _prepare_filter_order(filter_order, removed_index, fs): + order = [] + for i, o in enumerate(filter_order): + if i not in removed_index: + fo = int(o * fs) + fo = fo + 1 if fo % 2 == 0 else fo + order.append(fo) + return order + + @staticmethod + def _prepare_filter_cutoff_period(filter_cutoff_period, fs): + """Frequency must be smaller than the sampling frequency fs. Otherwise remove given cutoff period pair.""" + cutoff = [] + removed = [] + for i, period in enumerate(to_list(filter_cutoff_period)): + if period > 2. / fs: + cutoff.append(period) + else: + removed.append(i) + return cutoff, removed + + @staticmethod + def _period_to_freq(cutoff_p): + return [1. / x for x in cutoff_p] + + +class DataHandlerClimateFirFilter(DefaultDataHandler): + """Data handler using climatic adjusted FIR filtered data.""" + + data_handler = DataHandlerClimateFirFilterSingleStation + data_handler_transformation = DataHandlerClimateFirFilterSingleStation + _requirements = data_handler.requirements() + _store_attributes = data_handler.store_attributes() diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py index ad2fd12d41dac6902ba8e8a078b52165f1d130c8..4c386885f66133e009ea211b9586cb07c38f28b4 100644 --- a/mlair/helpers/filter.py +++ b/mlair/helpers/filter.py @@ -3,6 +3,7 @@ import warnings from typing import Union import numpy as np +import pandas as pd from matplotlib import pyplot as plt from scipy import signal import xarray as xr @@ -17,8 +18,8 @@ class FIRFilter: filtered = [] h = [] for i in range(len(order)): - fi, hi = self.apply_fir_filter(data, fs, order[i], cutoff_low=cutoff[i][0], cutoff_high=cutoff[i][1], - window=window, dim=dim) + fi, hi = fir_filter(data, fs, order=order[i], cutoff_low=cutoff[i][0], cutoff_high=cutoff[i][1], + window=window, dim=dim, h=None, causal=True, padlen=None) filtered.append(fi) h.append(hi) @@ -47,31 +48,146 @@ class FIRFilter: # cutoff_high=cutoff[3][1], window=window) # filtered_high = xr.ones_like(station_data) * y_high.reshape(station_data.values.shape) - def apply_fir_filter(self, data, fs, order=5, cutoff_low=None, cutoff_high=None, window="hamming", dim="variables"): - - # create fir filter coeffs - cutoff = [] - if cutoff_low is not None: - cutoff += [cutoff_low] - if cutoff_high is not None: - cutoff += [cutoff_high] - if len(cutoff) == 2: - filter_type = "bandpass" - elif len(cutoff) == 1 and cutoff_low is not None: - filter_type = "highpass" - elif len(cutoff) == 1 and cutoff_high is not None: - filter_type = "lowpass" - else: - raise ValueError("Please provide either cutoff_low or cutoff_high.") - h = signal.firwin(order, cutoff, pass_zero=filter_type, fs=fs, window=window) - # filter data - filtered = xr.ones_like(data) - for var in data.coords[dim]: - d = data.sel({dim: var}).values.flatten() +class ClimateFIRFilter: + + def __init__(self, data, fs, order, cutoff, window, time_dim, var_dim, apriori=None, apriori_type=None, + sel_opts=None): + """ + :param data: data to filter + :param fs: sampling frequency in 1/days -> 1d: fs=1 -> 1H: fs=24 + :param order: a tuple with the order of the filter in same ordering like cutoff + :param cutoff: a tuple with the cutoff frequencies (all are applied as low pass) + :param window: window type of the filter (e.g. hamming) + :param time_dim: name of time dimension to apply filter along + :param var_dim: name of variables dimension + :param apriori: apriori information to use for the first low pass. If None, climatology is calculated on the + provided data. + :param apriori_type: type of apriori information to use. Climatology will be used always for first low pass. For + the residuum either the value zero is used (apriori_type is None or "zeros") or a climatology on the + residua is used ("residuum_stats"). + """ + filtered = [] + h = [] + sel_opts = sel_opts if isinstance(sel_opts, dict) else {time_dim: sel_opts} + sampling = {1: "1d", 24: "1H"}.get(int(fs)) + if apriori is None: + apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim) + apriori_list = to_list(apriori) + input_data = data.__deepcopy__() + for i in range(len(order)): + fi, hi, apriori = self.clim_filter(input_data, fs, cutoff[i], order[i], apriori=apriori_list[i], + sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, window=window, + var_dim=var_dim) + filtered.append(fi) + h.append(hi) + input_data = input_data - fi # calculate residuum + if len(apriori_list) <= i + 1: + if apriori_type is None or apriori_type == "zeros": + apriori_list.append(xr.zeros_like(apriori_list[i])) # zero version + elif apriori_type == "residuum_stats": + apriori_list.append(-self.create_monthly_mean(input_data, sel_opts=sel_opts, sampling=sampling, + time_dim=time_dim)) + else: + raise ValueError(f"Cannot handle unkown apriori type: {apriori_type}. Please choose from None, " + f"`zeros` or `residuum_stats`.") + # add residuum to filtered + filtered.append(input_data) + self._filtered = filtered + self._h = h + self._apriori = apriori_list + + @staticmethod + def create_monthly_mean(data, sel_opts=None, sampling="1d", time_dim="datetime"): + monthly = xr.ones_like(data) + if sel_opts is not None: + data = data.sel(**sel_opts) + monthly_mean = data.groupby(f"{time_dim}.month").mean() + for month in monthly_mean.month.values: + loc = (monthly[f"{time_dim}.month"] == month) + monthly.loc[{time_dim: loc}] = monthly_mean.sel(month=month) + return monthly.resample({time_dim: "1m"}).mean().resample({time_dim: sampling}).interpolate() + + def clim_filter(self, data, fs, cutoff_high, order, apriori=None, padlen=None, sel_opts=None, sampling="1d", + time_dim="datetime", var_dim="variables", window="hamming"): + if apriori is None: + apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim) + h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window) + length = len(h) + dt = data.coords[time_dim].values + res = xr.zeros_like(data) + print("start iteration") + for i in range(0, len(dt)): + t0 = dt[i] + pd_date = pd.to_datetime(t0) + if pd_date.day == 1 and pd_date.month == 1: + print(t0) + try: + i_m = max(0, i - length) + i_p = min(i + length, len(dt) - 2) + t_hist = slice(dt[i_m], dt[i]) + t_fut = slice(dt[i + 1], dt[i_p + 1]) + tmp_hist = data.sel({time_dim: t_hist}) + tmp_fut = apriori.sel({time_dim: t_fut}) + tmp_comb = xr.concat([tmp_hist, tmp_fut], dim=time_dim) + _padlen = padlen if padlen is not None else int(0.5 * len(tmp_comb.coords[time_dim])) + tmp_filter, _ = fir_filter(tmp_comb, fs, cutoff_high=cutoff_high, order=order, causal=False, + padlen=_padlen, dim=var_dim, window=window, h=h) + res.loc[{time_dim: t0}] = tmp_filter.loc[{time_dim: t0}] + except IndexError: + pass + # if i == 720: + # for var in data.coords[var_dim]: + # data.sel({var_dim: var, time_dim: slice(dt[i_m], dt[i_p+1])}).plot() + # tmp_comb.sel({var_dim: var}).plot() + # plt.title(var) + # plt.show() + return res, h, apriori + + @property + def filter_coefficients(self): + return self._h + + @property + def filtered_data(self): + return self._filtered + + @property + def apriori_data(self): + return self._apriori + + @property + def initial_apriori_data(self): + return self.apriori_data[0] + + +def fir_filter(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="hamming", dim="variables", h=None, + causal=True, padlen=None): + cutoff = [] + if cutoff_low is not None: + cutoff += [cutoff_low] + if cutoff_high is not None: + cutoff += [cutoff_high] + if len(cutoff) == 2: + filter_type = "bandpass" + elif len(cutoff) == 1 and cutoff_low is not None: + filter_type = "highpass" + elif len(cutoff) == 1 and cutoff_high is not None: + filter_type = "lowpass" + else: + raise ValueError("Please provide either cutoff_low or cutoff_high.") + if h is None: + h = signal.firwin(order, cutoff, pass_zero=filter_type, fs=fs, window=window) + filtered = xr.ones_like(data) + for var in data.coords[dim]: + d = data.sel({dim: var}).values.flatten() + if causal: y = signal.lfilter(h, 1., d) - filtered.loc[{dim: var}] = y - return filtered, h + else: + padlen = padlen if padlen is not None else 3 * len(h) + y = signal.filtfilt(h, 1., d, padlen=padlen) + filtered.loc[{dim: var}] = y + return filtered, h class KolmogorovZurbenkoBaseClass: