diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py index 1e864ee8c366ae0155a2dadb7b9619ff9924ca69..3f5ee5f3e8b2d9682fc0f3d1780da7870d64e1fd 100644 --- a/mlair/helpers/filter.py +++ b/mlair/helpers/filter.py @@ -11,7 +11,7 @@ from matplotlib import pyplot as plt from scipy import signal import xarray as xr -from mlair.helpers import to_list, TimeTrackingWrapper +from mlair.helpers import to_list, TimeTrackingWrapper, TimeTracking class FIRFilter: @@ -258,14 +258,19 @@ class ClimateFIRFilter: # combine historical data / observation [t0-length,t0] and climatological statistics [t0+1,t0+length] history = self._shift_data(data, range(-length, 1), time_dim, var_dim, new_dim) future = self._shift_data(apriori, range(1, length + 1), time_dim, var_dim, new_dim) - filter_input_data = history.combine_first(future) + filter_input_data = xr.concat([history.dropna(time_dim), future], dim=new_dim, join="left") + # filter_input_data = history.combine_first(future) time_axis = filter_input_data.coords["datetime"] # apply vectorized fir filter along the tmp dimension - filt = xr.apply_ufunc(fir_filter_vectorized, filter_input_data, time_axis, - input_core_dims=[[new_dim], []], output_core_dims=[[new_dim]], vectorize=True, - kwargs={"fs": fs, "cutoff_high": cutoff_high, "order": order, - "causal": False, "padlen": int(min(padlen_factor, 1) * length)}) + kwargs = {"fs": fs, "cutoff_high": cutoff_high, "order": order, + "causal": False, "padlen": int(min(padlen_factor, 1) * length)} + with TimeTracking(): + filt = fir_filter_numpy_vectorized(filter_input_data, var_dim, kwargs) + # with TimeTracking(): + # filt = xr.apply_ufunc(fir_filter_vectorized, filter_input_data, time_axis, + # input_core_dims=[[new_dim], []], output_core_dims=[[new_dim]], vectorize=True, + # kwargs=kwargs) # plot if self.plot_path is not None: @@ -383,13 +388,23 @@ def fir_filter(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="ham return filtered, h -def fir_filter_vectorized(data, time_stamp, fs, order=5, cutoff_low=None, cutoff_high=None, window="hamming", h=None, +def fir_filter_numpy_vectorized(filter_input_data, var_dim, kwargs): + filt_np = xr.DataArray(np.nan, coords=filter_input_data.coords) + for var in filter_input_data.coords[var_dim]: + a = np.apply_along_axis(fir_filter_vectorized, 2, filter_input_data.sel({var_dim: var}).values, **kwargs) + filt_np.loc[{var_dim: var}] = a + return filt_np + + +def fir_filter_vectorized(data, time_stamp=None, fs=1, order=5, cutoff_low=None, cutoff_high=None, window="hamming", + h=None, causal=True, padlen=None): """Expects numpy array.""" - pd_date = pd.to_datetime(time_stamp) - if pd_date.day == 1 and pd_date.month in [1, 7]: - logging.info(time_stamp) + if time_stamp is not None: + pd_date = pd.to_datetime(time_stamp) + if pd_date.day == 1 and pd_date.month in [1, 7]: + logging.info(time_stamp) sel = ~np.isnan(data) res = np.empty_like(data) if h is None: