diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py index 5b521f7d8aeb56f781ab0a00387f6025e93c2af2..df5522da34cce03c013e68a608cb52e11212e9cd 100644 --- a/mlair/helpers/filter.py +++ b/mlair/helpers/filter.py @@ -1,6 +1,6 @@ import gc import warnings -from typing import Union +from typing import Union, Callable import logging import os @@ -55,7 +55,7 @@ class FIRFilter: class ClimateFIRFilter: def __init__(self, data, fs, order, cutoff, window, time_dim, var_dim, apriori=None, apriori_type=None, - sel_opts=None, plot_path=None, plot_name=None): + sel_opts=None, plot_path=None, plot_name=None, vectorized=True, padlen_factor=0.8): """ :param data: data to filter :param fs: sampling frequency in 1/days -> 1d: fs=1 -> 1H: fs=24 @@ -82,9 +82,11 @@ class ClimateFIRFilter: input_data = data.__deepcopy__() for i in range(len(order)): # calculate climatological filter - fi, hi, apriori = self.clim_filter(input_data, fs, cutoff[i], order[i], apriori=apriori_list[i], - sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, window=window, - var_dim=var_dim, plot_index=i) + clim_filter: Callable = {True: self.clim_filter_vectorized, False: self.clim_filter}[vectorized] + fi, hi, apriori = clim_filter(input_data.sel({time_dim: slice("2006")}), fs, cutoff[i], order[i], + apriori=apriori_list[i], + sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, window=window, + var_dim=var_dim, plot_index=i, padlen_factor=padlen_factor) filtered.append(fi) h.append(hi) @@ -196,7 +198,8 @@ class ClimateFIRFilter: return apriori - def clim_filter(self, data, fs, cutoff_high, order, apriori=None, padlen=None, sel_opts=None, sampling="1d", + @TimeTrackingWrapper + def clim_filter(self, data, fs, cutoff_high, order, apriori=None, padlen_factor=0.5, sel_opts=None, sampling="1d", time_dim="datetime", var_dim="variables", window="hamming", plot_index=None): # calculate apriori information from data if not given and extend its range if not sufficient long enough @@ -225,7 +228,7 @@ 
class ClimateFIRFilter:
    # Vectorized climatological low-pass filtering and the helpers that build its stacked window input.

    @TimeTrackingWrapper
    def clim_filter_vectorized(self, data, fs, cutoff_high, order, apriori=None, padlen_factor=0.5, sel_opts=None,
                               sampling="1d", time_dim="datetime", var_dim="variables", window="hamming",
                               plot_index=None):
        """
        Low-pass filter each time step on a window of past observations and future climatology.

        For every point in time t0 a window is built along a temporary dimension: window positions
        [-length, 0] hold values of `data` and positions [1, length] hold values of `apriori`, where `length`
        is the number of FIR coefficients. The non-causal filter is applied along that temporary dimension and
        only the value at window position 0 is kept.

        :param data: data to filter (presumably an xarray.DataArray - confirm against caller)
        :param fs: sampling frequency passed to the FIR design
        :param cutoff_high: cutoff frequency of the low-pass filter
        :param order: number of FIR filter coefficients
        :param apriori: climatological estimate used for the future part of the window; calculated as monthly
            mean of `data` if not given
        :param padlen_factor: fraction of the filter length used as padlen (values above 1 are treated as 1)
        :param sel_opts: forwarded to create_monthly_mean when apriori has to be calculated
        :param sampling: forwarded to create_monthly_mean when apriori has to be calculated
        :param time_dim: name of the time dimension
        :param var_dim: name of the variables dimension
        :param window: window function used for the FIR design
        :param plot_index: index that is used in the name of the plot file
        :return: filtered data at window position 0, the filter coefficients h, the (extended) apriori
        """
        # calculate apriori information from data if not given and extend its range if not sufficient long enough
        if apriori is None:
            apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim)
        apriori = self.extend_apriori(data, apriori, time_dim)

        # calculate FIR filter coefficients
        h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window)
        length = len(h)

        # create tmp dimension to apply filter, search for unused name
        new_dim = self._create_tmp_dimension(data)

        # combine historical data / observation [t0-length,t0] and climatological statistics [t0+1,t0+length]
        history = self._shift_data(data, range(-length, 1), time_dim, var_dim, new_dim)
        future = self._shift_data(apriori, range(1, length + 1), time_dim, var_dim, new_dim)
        filter_input_data = history.combine_first(future)

        # apply vectorized fir filter along the tmp dimension; hand over the coefficients calculated above so
        # that the returned h is exactly the filter that has been applied (and is not re-designed per window
        # with the default window function)
        filter_kwargs = {"fs": fs, "cutoff_high": cutoff_high, "order": order, "window": window, "h": h,
                         "causal": False, "padlen": int(min(padlen_factor, 1) * length)}
        filt = xr.apply_ufunc(fir_filter_vectorized, filter_input_data,
                              input_core_dims=[[new_dim]], output_core_dims=[[new_dim]], vectorize=True,
                              kwargs=filter_kwargs)

        # plot an example window; the preferred position is index 720, but it is moved if needed so that the
        # full window [pos-length, pos+length] lies inside the time axis (a negative slice start would silently
        # select from the end of the axis)
        if self.plot_path is not None:
            n_time = len(filter_input_data.coords[time_dim])
            if n_time >= 2 * length + 1:
                pos = min(max(720, length), n_time - length - 1)
                filter_example = filter_input_data.isel({time_dim: pos})
                t0 = filter_example.coords[time_dim].values
                t_slice = filter_input_data.isel(
                    {time_dim: slice(pos - length, pos + length + 1)}).coords[time_dim].values
                self.plot(data, filter_example, var_dim, time_dim, t_slice, t0, plot_index)
            else:
                logging.debug("Skip climFIR plot: time axis is shorter than a single filter window.")

        # select only values at tmp dimension 0 at each point in time
        return filt.sel({new_dim: 0}, drop=True), h, apriori

    @staticmethod
    def _create_tmp_dimension(data):
        """
        Find a dimension name that is not used in data.

        Starts with "window" and doubles the name as long as it is already a dimension of data.

        :param data: object with a `dims` attribute
        :return: unused dimension name
        :raises ValueError: if no unused name has been found after more than 10 attempts
        """
        new_dim = "window"
        count = 0
        while new_dim in data.dims:
            new_dim += new_dim
            count += 1
            if count > 10:
                raise ValueError("Could not create new dimension.")
        return new_dim

    def _shift_data(self, data, index_value, time_dim, squeeze_dim, new_dim):
        """
        Stack shifted copies of data along a new dimension.

        For each i in index_value, data is shifted by -i along time_dim, so that the entry at a given time
        holds the value found i steps later (earlier for negative i). All copies are concatenated along
        new_dim, which is labelled with the values of index_value.

        :param data: data to shift
        :param index_value: re-iterable collection of integer offsets (e.g. a range)
        :param time_dim: name of the dimension to shift along
        :param squeeze_dim: helper dimension name used while creating the index array
        :param new_dim: name of the new dimension
        :return: concatenated shifted data
        """
        coll = [data.shift({time_dim: -i}) for i in index_value]
        new_ind = self.create_index_array(new_dim, index_value, squeeze_dim)
        return xr.concat(coll, dim=new_ind)

    @staticmethod
    def create_index_array(index_name: str, index_value, squeeze_dim: str):
        """
        Create a named index array that can be used as `dim` argument of xr.concat.

        :param index_name: name of the resulting array and of its dimension
        :param index_value: values of the index
        :param squeeze_dim: temporary dimension name that is created by to_array and squeezed again
        :return: index array named index_name
        """
        ind = pd.DataFrame({'val': index_value}, index=index_value)
        res = xr.Dataset.from_dataframe(ind).to_array(squeeze_dim).rename({'index': index_name}).squeeze(
            dim=squeeze_dim,
            drop=True)
        res.name = index_name
        return res
'large', + 'axes.titlesize': 'large', + } + plt.rcParams.update(rc_params) + fig, ax = plt.subplots() + ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)") + ax.plot(time_axis, data.sel({var_dim: var, time_dim: time_dim_slice}).values.flatten(), + color="darkgrey", linestyle="--", label="original") + ax.plot(time_axis, tmp_comb.sel({var_dim: var}).values.flatten(), color="black", label="filter input") + # data.sel({var_dim: var, time_dim: time_dim_slice}).plot() + # tmp_comb.sel({var_dim: var}).plot() + plt.title(f"Input of ClimFilter ({str(var.values)})") + plt.legend() + fig.autofmt_xdate() + plt.tight_layout() plot_name = os.path.join(plot_folder, f"climFIR_{self.plot_name}_{str(var.values)}_{plot_index}.pdf") plt.savefig(plot_name, dpi=300) plt.close('all') @@ -270,20 +354,21 @@ class ClimateFIRFilter: def fir_filter(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="hamming", dim="variables", h=None, causal=True, padlen=None): - cutoff = [] - if cutoff_low is not None: - cutoff += [cutoff_low] - if cutoff_high is not None: - cutoff += [cutoff_high] - if len(cutoff) == 2: - filter_type = "bandpass" - elif len(cutoff) == 1 and cutoff_low is not None: - filter_type = "highpass" - elif len(cutoff) == 1 and cutoff_high is not None: - filter_type = "lowpass" - else: - raise ValueError("Please provide either cutoff_low or cutoff_high.") + """Expects xarray.""" if h is None: + cutoff = [] + if cutoff_low is not None: + cutoff += [cutoff_low] + if cutoff_high is not None: + cutoff += [cutoff_high] + if len(cutoff) == 2: + filter_type = "bandpass" + elif len(cutoff) == 1 and cutoff_low is not None: + filter_type = "highpass" + elif len(cutoff) == 1 and cutoff_high is not None: + filter_type = "lowpass" + else: + raise ValueError("Please provide either cutoff_low or cutoff_high.") h = signal.firwin(order, cutoff, pass_zero=filter_type, fs=fs, window=window) filtered = xr.ones_like(data) for var in data.coords[dim]: @@ -297,6 +382,38 @@ 
def fir_filter_vectorized(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="hamming", h=None, causal=True,
                          padlen=None):
    """
    Apply a FIR filter to a numpy array and ignore NaN entries.

    Expects numpy array. NaN entries are removed before filtering, the remaining values are filtered as one
    contiguous series and written back to their original positions. All positions without a filter result
    (NaN in the input, or the whole output if the series is too short) are NaN.

    :param data: numpy array to filter
    :param fs: sampling frequency used for the filter design (not used if h is given)
    :param order: number of filter coefficients (not used if h is given)
    :param cutoff_low: lower cutoff frequency; alone it results in a highpass, together with cutoff_high in a
        bandpass (not used if h is given)
    :param cutoff_high: upper cutoff frequency; alone it results in a lowpass (not used if h is given)
    :param window: window function used for the filter design (not used if h is given)
    :param h: precomputed filter coefficients; designed with scipy.signal.firwin if None
    :param causal: use the one-directional lfilter if True, the forward-backward filtfilt otherwise
    :param padlen: padding length for filtfilt, defaults to 3 * len(h)
    :return: array of the same shape as data
    :raises ValueError: if h is None and neither cutoff_low nor cutoff_high is given
    """
    data = np.asarray(data)
    sel = ~np.isnan(data)
    # start from NaN instead of uninitialised memory so that positions without result are well defined
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    res = np.full(data.shape, np.nan, dtype=out_dtype)

    if h is None:
        if cutoff_low is not None and cutoff_high is not None:
            cutoff, filter_type = [cutoff_low, cutoff_high], "bandpass"
        elif cutoff_low is not None:
            cutoff, filter_type = [cutoff_low], "highpass"
        elif cutoff_high is not None:
            cutoff, filter_type = [cutoff_high], "lowpass"
        else:
            raise ValueError("Please provide either cutoff_low or cutoff_high.")
        h = signal.firwin(order, cutoff, pass_zero=filter_type, fs=fs, window=window)

    n_valid = np.count_nonzero(sel)
    if n_valid == 0:
        # nothing to filter
        return res

    if causal:
        res[sel] = signal.lfilter(h, 1., data[sel])
    else:
        padlen = padlen if padlen is not None else 3 * len(h)
        # filtfilt requires more samples than padlen, otherwise keep NaN
        if n_valid > padlen:
            res[sel] = signal.filtfilt(h, 1., data[sel], padlen=padlen)
    return res