diff --git a/mlair/helpers/filter.py b/mlair/helpers/filter.py index 10c9dd8fd87e25293ff6a74b69c93ce585e1e49a..247c4fc9c7c6d57d721c1d0895cc8c719b1bd4a5 100644 --- a/mlair/helpers/filter.py +++ b/mlair/helpers/filter.py @@ -640,8 +640,7 @@ class ClimateFIRFilter(FIRFilter): @staticmethod def _trim_data_to_minimum_length(data: xr.DataArray, extend_length_history: int, dim: str, - minimum_length: int = None, extend_length_future: int = 0, - offset: int = 0) -> xr.DataArray: + extend_length_future: int = 0) -> xr.DataArray: """ Trim data along given axis between either -minimum_length (if given) or -extend_length_history and extend_length_opts (which is default set to 0). @@ -650,14 +649,9 @@ class ClimateFIRFilter(FIRFilter): :param extend_length_history: start number for trim range (transformed to negative), only used if parameter minimum_length is not provided :param dim: dim to apply trim on - :param minimum_length: start number for trim range (transformed to negative), preferably used (default None) :param extend_length_future: number to use in "future" :returns: trimmed data """ - # if minimum_length is None: - # return data.sel({dim: slice(-extend_length_history, extend_length_future)}, drop=True) - # else: - # return data.sel({dim: slice(-minimum_length + offset, extend_length_future)}, drop=True) return data.sel({dim: slice(-extend_length_history, extend_length_future)}, drop=True) @staticmethod @@ -749,13 +743,12 @@ class ClimateFIRFilter(FIRFilter): # trim data if required valid_range_end = int(filt.coords[new_dim].max() - (length + 1) / 2) + 1 - trimmed = self._trim_data_to_minimum_length(filt, extend_length_history, new_dim, minimum_length + int((next_order + 1) / 2), - extend_length_future=min(extend_length_future, valid_range_end), - offset=offset) + ext_len = min(extend_length_future, valid_range_end) + trimmed = self._trim_data_to_minimum_length(filt, extend_length_history, new_dim, + extend_length_future=ext_len) filt_coll.append(trimmed) - trimmed = self._trim_data_to_minimum_length(filter_input_data, extend_length_history, new_dim, minimum_length + int((next_order + 1) / 2), - extend_length_future=min(extend_length_future, valid_range_end), - offset=offset) + trimmed = self._trim_data_to_minimum_length(filter_input_data, extend_length_history, new_dim, + extend_length_future=ext_len) filt_input_coll.append(trimmed) # visualization diff --git a/test/test_helpers/test_filter.py b/test/test_helpers/test_filter.py index 3c362911dea81f900c377ab19ff27abbb67f7214..e4bfb6890936d13137ebb6dda01a44eed0166ae5 100644 --- a/test/test_helpers/test_filter.py +++ b/test/test_helpers/test_filter.py @@ -297,10 +297,11 @@ class TestClimateFIRFilter: res = obj._trim_data_to_minimum_length(xr_array, 5, "window") assert xr_array.shape == (21, 100, 1) assert res.shape == (6, 100, 1) - res = obj._trim_data_to_minimum_length(xr_array, 5, "window", 10) - assert res.shape == (11, 100, 1) res = obj._trim_data_to_minimum_length(xr_array, 30, "window") assert res.shape == (21, 100, 1) + xr_array = obj._shift_data(xr_array.sel(window=0), range(-20, 5), time_dim, new_dim="window") + res = obj._trim_data_to_minimum_length(xr_array, 5, "window", extend_length_future=2) + assert res.shape == (8, 100, 1) def test_create_full_filter_result_array(self, xr_array, time_dim): obj = object.__new__(ClimateFIRFilter) @@ -320,7 +321,7 @@ class TestClimateFIRFilter: assert len(res) == 5 # check filter data properties - assert res[0].shape == (*xr_array_long_with_var.shape, 24 + 2) + assert res[0].shape == (*xr_array_long_with_var.shape, int(filter_order+1)/2 + 24 + 2) assert res[0].dims == (*xr_array_long_with_var.dims, "window") # check filter properties @@ -350,7 +351,7 @@ class TestClimateFIRFilter: var_dim=var_dim, new_dim="total_new_dim", window=("kaiser", 5), minimum_length=1000, apriori=apriori, plot_dates=plot_dates) - assert res[0].shape == (*xr_array_long_with_var.shape, 1000 + 2) + assert res[0].shape == (*xr_array_long_with_var.shape, int(10 * 24 + 1 + 1) / 2 + 1000 + 2) assert res[0].dims == (*xr_array_long_with_var.dims, "total_new_dim") assert np.testing.assert_almost_equal( res[2], obj._calculate_filter_coefficients(("kaiser", 5), filter_order, 0.05, 24)) is None