diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py index 0957d41b710fae5c9bd446b6fe1aaa3edcd1a3d9..82f0020fafb0a0a1c27386f1df6ce545f691b63e 100644 --- a/mlair/helpers/filter.py +++ b/mlair/helpers/filter.py @@ -267,11 +267,11 @@ class ClimateFIRFilter: kwargs = {"fs": fs, "cutoff_high": cutoff_high, "order": order, "causal": False, "padlen": int(min(padlen_factor, 1) * length)} with TimeTracking(): - filt = fir_filter_numpy_vectorized(filter_input_data, var_dim, kwargs) - with TimeTracking(): - filt = xr.apply_ufunc(fir_filter_vectorized, filter_input_data, time_axis, - input_core_dims=[[new_dim], []], output_core_dims=[[new_dim]], vectorize=True, - kwargs=kwargs) + filt = fir_filter_numpy_vectorized(filter_input_data, var_dim, new_dim, kwargs) + # with TimeTracking(): + # filt = xr.apply_ufunc(fir_filter_vectorized, filter_input_data, time_axis, + # input_core_dims=[[new_dim], []], output_core_dims=[[new_dim]], vectorize=True, + # kwargs=kwargs) # plot if self.plot_path is not None: @@ -393,12 +393,13 @@ def fir_filter(data, fs, order=5, cutoff_low=None, cutoff_high=None, window="ham return filtered, h -def fir_filter_numpy_vectorized(filter_input_data, var_dim, kwargs): +def fir_filter_numpy_vectorized(filter_input_data, var_dim, new_dim, kwargs): filt_np = xr.DataArray(np.nan, coords=filter_input_data.coords) for var in filter_input_data.coords[var_dim]: logging.info( f"{filter_input_data.coords['Stations'].values[0]}: {str(var.values)}") # ToDo must be removed, just for debug - a = da.apply_along_axis(fir_filter_vectorized, 2, filter_input_data.sel({var_dim: var}).values, **kwargs) + a = np.apply_along_axis(fir_filter_vectorized, filter_input_data.dims.index(new_dim), + filter_input_data.sel({var_dim: var}).values, **kwargs) filt_np.loc[{var_dim: var}] = a return filt_np