Skip to content
Snippets Groups Projects
Commit 15dc6218 authored by leufen1
Browse files

implemented vectorized version of fir filter for faster computation

parent be60dba2
No related branches found
No related tags found
5 merge requests: !319 add all changes of dev into release v1.4.0 branch, !318 Resolve "release v1.4.0", !317 enabled window_lead_time=1, !295 Resolve "data handler FIR filter", !259 Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #67231 passed
This commit is part of merge request !318. Comments created here will be created in the context of that merge request.
import gc
import warnings
from typing import Union
from typing import Union, Callable
import logging
import os
......@@ -55,7 +55,7 @@ class FIRFilter:
class ClimateFIRFilter:
def __init__(self, data, fs, order, cutoff, window, time_dim, var_dim, apriori=None, apriori_type=None,
sel_opts=None, plot_path=None, plot_name=None):
sel_opts=None, plot_path=None, plot_name=None, vectorized=True, padlen_factor=0.8):
"""
:param data: data to filter
:param fs: sampling frequency in 1/days -> 1d: fs=1 -> 1H: fs=24
......@@ -82,9 +82,11 @@ class ClimateFIRFilter:
input_data = data.__deepcopy__()
for i in range(len(order)):
# calculate climatological filter
fi, hi, apriori = self.clim_filter(input_data, fs, cutoff[i], order[i], apriori=apriori_list[i],
clim_filter: Callable = {True: self.clim_filter_vectorized, False: self.clim_filter}[vectorized]
fi, hi, apriori = clim_filter(input_data.sel({time_dim: slice("2006")}), fs, cutoff[i], order[i],
apriori=apriori_list[i],
sel_opts=sel_opts, sampling=sampling, time_dim=time_dim, window=window,
var_dim=var_dim, plot_index=i)
var_dim=var_dim, plot_index=i, padlen_factor=padlen_factor)
filtered.append(fi)
h.append(hi)
......@@ -196,7 +198,8 @@ class ClimateFIRFilter:
return apriori
def clim_filter(self, data, fs, cutoff_high, order, apriori=None, padlen=None, sel_opts=None, sampling="1d",
@TimeTrackingWrapper
def clim_filter(self, data, fs, cutoff_high, order, apriori=None, padlen_factor=0.5, sel_opts=None, sampling="1d",
time_dim="datetime", var_dim="variables", window="hamming", plot_index=None):
# calculate apriori information from data if not given and extend its range if not sufficient long enough
......@@ -225,7 +228,7 @@ class ClimateFIRFilter:
tmp_hist = data.sel({time_dim: t_hist})
tmp_fut = apriori.sel({time_dim: t_fut})
tmp_comb = xr.concat([tmp_hist, tmp_fut], dim=time_dim)
_padlen = padlen if padlen is not None else int(0.5 * len(tmp_comb.coords[time_dim]))
_padlen = int(min(padlen_factor, 1) * len(tmp_comb.coords[time_dim]))
tmp_filter, _ = fir_filter(tmp_comb, fs, cutoff_high=cutoff_high, order=order, causal=False,
padlen=_padlen, dim=var_dim, window=window, h=h)
res.loc[{time_dim: t0}] = tmp_filter.loc[{time_dim: t0}]
......@@ -235,16 +238,97 @@ class ClimateFIRFilter:
res.loc[{time_dim: t0}] = np.nan
return res, h, apriori
@TimeTrackingWrapper
def clim_filter_vectorized(self, data, fs, cutoff_high, order, apriori=None, padlen_factor=0.5, sel_opts=None,
                           sampling="1d", time_dim="datetime", var_dim="variables", window="hamming",
                           plot_index=None):
    """
    Apply the climatological low-pass FIR filter to all time steps at once.

    For each time step t0, the filter input is built from the observations in [t0-length, t0] and the apriori
    information in [t0+1, t0+length], where length is the number of filter coefficients. These windows are
    stacked along a temporary dimension, filtered non-causally along it, and only the value at window
    position 0 (i.e. at t0) is kept.

    :param data: data to filter
    :param fs: sampling frequency, passed to the filter design
    :param cutoff_high: cutoff frequency of the low-pass filter
    :param order: number of filter coefficients
    :param apriori: apriori information used for the future part of each window; calculated from data by
        create_monthly_mean if not given
    :param padlen_factor: fraction of the filter length used as padlen (values above 1 are clipped to 1)
    :param sel_opts: forwarded to create_monthly_mean
    :param sampling: forwarded to create_monthly_mean
    :param time_dim: name of the time dimension
    :param var_dim: name of the variables dimension
    :param window: window function used to design the filter coefficients
    :param plot_index: index used in the plot file name
    :return: filtered data at each t0, the filter coefficients h, and the (extended) apriori information
    """
    # calculate apriori information from data if not given and extend its range if not sufficient long enough
    if apriori is None:
        apriori = self.create_monthly_mean(data, sel_opts=sel_opts, sampling=sampling, time_dim=time_dim)
    apriori = self.extend_apriori(data, apriori, time_dim)

    # calculate FIR filter coefficients
    h = signal.firwin(order, cutoff_high, pass_zero="lowpass", fs=fs, window=window)
    length = len(h)

    # create tmp dimension to apply filter, search for unused name
    new_dim = self._create_tmp_dimension(data)

    # combine historical data / observation [t0-length,t0] and climatological statistics [t0+1,t0+length]
    history = self._shift_data(data, range(-length, 1), time_dim, var_dim, new_dim)
    future = self._shift_data(apriori, range(1, length + 1), time_dim, var_dim, new_dim)
    filter_input_data = history.combine_first(future)

    # apply vectorized fir filter along the tmp dimension; hand over the coefficients calculated above so that
    # the applied filter always matches the returned h (also for a non-default window) and the filter design is
    # not repeated for every single slice
    filt = xr.apply_ufunc(fir_filter_vectorized, filter_input_data,
                          input_core_dims=[[new_dim]], output_core_dims=[[new_dim]], vectorize=True,
                          kwargs={"fs": fs, "cutoff_high": cutoff_high, "order": order, "window": window,
                                  "h": h, "causal": False, "padlen": int(min(padlen_factor, 1) * length)})

    # plot
    if self.plot_path is not None:
        # NOTE(review): fixed example position, presumably requires more than 720 time steps -- confirm
        pos = 720
        filter_example = filter_input_data.isel({time_dim: pos})
        t0 = filter_example.coords[time_dim].values
        t_slice = filter_input_data.isel({time_dim: slice(pos - length, pos + length + 1)}).coords[time_dim].values
        self.plot(data, filter_example, var_dim, time_dim, t_slice, t0, plot_index)

    # select only values at tmp dimension 0 at each point in time
    return filt.sel({new_dim: 0}, drop=True), h, apriori
@staticmethod
def _create_tmp_dimension(data):
new_dim = "window"
count = 0
while new_dim in data.dims:
new_dim += new_dim
count += 1
if count > 10:
raise ValueError("Could not create new dimension.")
return new_dim
def _shift_data(self, data, index_value, time_dim, squeeze_dim, new_dim):
    """
    Stack shifted copies of data along a new dimension.

    For each entry i of index_value, data is shifted by -i along time_dim. All shifted copies are concatenated
    along an index array named new_dim that is built from index_value by create_index_array.

    :param data: data to shift
    :param index_value: offsets to shift by, also used as values of the new index
    :param time_dim: name of the dimension to shift along
    :param squeeze_dim: forwarded to create_index_array
    :param new_dim: name of the new dimension
    :return: concatenation of all shifted copies along new_dim
    """
    shifted = [data.shift({time_dim: -offset}) for offset in index_value]
    window_index = self.create_index_array(new_dim, index_value, squeeze_dim)
    return xr.concat(shifted, dim=window_index)
@staticmethod
def create_index_array(index_name: str, index_value, squeeze_dim: str):
    """
    Create an index array named index_name from index_value.

    A data frame with index_value as both index and single column is converted to an xarray dataset and then
    to an array. Its 'index' dimension is renamed to index_name and the helper dimension squeeze_dim is dropped.

    :param index_name: name of the resulting array and of its dimension
    :param index_value: values of the index
    :param squeeze_dim: name of the helper dimension that is created by to_array and squeezed afterwards
    :return: array named index_name
    """
    frame = pd.DataFrame({'val': index_value}, index=index_value)
    stacked = xr.Dataset.from_dataframe(frame).to_array(squeeze_dim)
    index_array = stacked.rename({'index': index_name}).squeeze(dim=squeeze_dim, drop=True)
    index_array.name = index_name
    return index_array
def plot(self, data, tmp_comb, var_dim, time_dim, time_dim_slice, t0, plot_index):
try:
plot_folder = os.path.join(os.path.abspath(self.plot_path), "climFIR")
if not os.path.exists(plot_folder):
os.makedirs(plot_folder)
for var in data.coords[var_dim]:
data.sel({var_dim: var, time_dim: time_dim_slice}).plot()
tmp_comb.sel({var_dim: var}).plot()
plt.axvline(t0, color="lightgrey")
plt.title(str(var.values))
time_axis = data.sel({var_dim: var, time_dim: time_dim_slice}).coords[time_dim].values
rc_params = {'axes.labelsize': 'large',
'xtick.labelsize': 'large',
'ytick.labelsize': 'large',
'legend.fontsize': 'large',
'axes.titlesize': 'large',
}
plt.rcParams.update(rc_params)
fig, ax = plt.subplots()
ax.axvline(t0, color="lightgrey", lw=6, label="time of interest ($t_0$)")
ax.plot(time_axis, data.sel({var_dim: var, time_dim: time_dim_slice}).values.flatten(),
color="darkgrey", linestyle="--", label="original")
ax.plot(time_axis, tmp_comb.sel({var_dim: var}).values.flatten(), color="black", label="filter input")
# data.sel({var_dim: var, time_dim: time_dim_slice}).plot()
# tmp_comb.sel({var_dim: var}).plot()
plt.title(f"Input of ClimFilter ({str(var.values)})")
plt.legend()
fig.autofmt_xdate()
plt.tight_layout()
plot_name = os.path.join(plot_folder, f"climFIR_{self.plot_name}_{str(var.values)}_{plot_index}.pdf")
plt.savefig(plot_name, dpi=300)
plt.close('all')
......@@ -270,6 +354,8 @@ class ClimateFIRFilter:
def fir_filter(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="hamming", dim="variables", h=None,
causal=True, padlen=None):
"""Expects xarray."""
if h is None:
cutoff = []
if cutoff_low is not None:
cutoff += [cutoff_low]
......@@ -283,7 +369,6 @@ def fir_filter(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="ham
filter_type = "lowpass"
else:
raise ValueError("Please provide either cutoff_low or cutoff_high.")
if h is None:
h = signal.firwin(order, cutoff, pass_zero=filter_type, fs=fs, window=window)
filtered = xr.ones_like(data)
for var in data.coords[dim]:
......@@ -297,6 +382,38 @@ def fir_filter(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="ham
return filtered, h
def fir_filter_vectorized(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="hamming", h=None, causal=True,
                          padlen=None):
    """
    Apply a FIR filter to a 1-dimensional numpy array, ignoring NaN entries.

    Only the non-NaN samples are passed to the filter. In the result, positions that are NaN in data stay NaN.
    If the non-causal filter cannot be applied because there are not more valid samples than padlen, the entire
    result is NaN.

    :param data: 1-dimensional numpy array to filter
    :param fs: sampling frequency, used to design the filter if h is not given
    :param order: number of filter coefficients, used if h is not given
    :param cutoff_low: lower cutoff frequency (high-pass if given alone, band-pass together with cutoff_high)
    :param cutoff_high: upper cutoff frequency (low-pass if given alone, band-pass together with cutoff_low)
    :param window: window function for the filter design, used if h is not given
    :param h: precalculated filter coefficients; if given, no filter design takes place
    :param causal: apply the filter forward only (lfilter) if True, else forward and backward (filtfilt)
    :param padlen: padding length for the non-causal filter, defaults to 3 times the number of coefficients
    :return: numpy array of same shape as data (floating point; non-float input is returned as float64)
    :raises ValueError: if neither h nor any cutoff frequency is given
    """
    data = np.asarray(data)
    sel = ~np.isnan(data)
    # start from an all-NaN result so that skipped positions are well defined instead of uninitialised memory
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    res = np.full(data.shape, np.nan, dtype=out_dtype)

    if h is None:
        cutoff = [freq for freq in (cutoff_low, cutoff_high) if freq is not None]
        if len(cutoff) == 2:
            filter_type = "bandpass"
        elif cutoff_low is not None:
            filter_type = "highpass"
        elif cutoff_high is not None:
            filter_type = "lowpass"
        else:
            raise ValueError("Please provide either cutoff_low or cutoff_high.")
        h = signal.firwin(order, cutoff, pass_zero=filter_type, fs=fs, window=window)

    valid = data[sel]
    if causal:
        res[sel] = signal.lfilter(h, 1., valid)
    else:
        padlen = padlen if padlen is not None else 3 * len(h)
        # filtfilt requires more samples than padlen, otherwise keep the NaN result
        if valid.size > padlen:
            res[sel] = signal.filtfilt(h, 1., valid, padlen=padlen)
    return res
class KolmogorovZurbenkoBaseClass:
def __init__(self, df, wl, itr, is_child=False, filter_dim="window"):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment