diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 22a8ca4150c46ad1ce9ba0d005fa643e27906f53..8ad3e1e7ff583bd511d6311f2ab9de886f440fc9 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -215,32 +215,29 @@ class DefaultDataHandler(AbstractDataHandler): raise TypeError(f"Elements of list extreme_values have to be {number.__args__}, but at least element " f"{i} is type {type(i)}") + extremes_X, extremes_Y = None, None for extr_val in sorted(extreme_values): # check if some extreme values are already extracted - if (self._X_extreme is None) or (self._Y_extreme is None): - X = self._X - Y = self._Y + if (extremes_X is None) or (extremes_Y is None): + X, Y = self._X, self._Y + extremes_X, extremes_Y = X, Y else: # one extr value iteration is done already: self.extremes_label is NOT None... - X = self._X_extreme - Y = self._Y_extreme + X, Y = self._X_extreme, self._Y_extreme # extract extremes based on occurrence in labels other_dims = remove_items(list(Y.dims), dim) if extremes_on_right_tail_only: - extreme_idx = (Y > extr_val).any(dim=other_dims) + extreme_idx = (extremes_Y > extr_val).any(dim=other_dims) else: - extreme_idx = xr.concat([(Y < -extr_val).any(dim=other_dims[0]), - (Y > extr_val).any(dim=other_dims[0])], + extreme_idx = xr.concat([(extremes_Y < -extr_val).any(dim=other_dims[0]), + (extremes_Y > extr_val).any(dim=other_dims[0])], dim=other_dims[0]).any(dim=other_dims[0]) sel = extreme_idx[extreme_idx].coords[dim].values - extremes_X = list(map(lambda x: x.sel(**{dim: sel}), X)) + extremes_X = list(map(lambda x: x.sel(**{dim: sel}), extremes_X)) self._add_timedelta(extremes_X, dim, timedelta) - # extremes_X = list(map(lambda x: x.coords[dim].values + np.timedelta64(*timedelta), extremes_X)) - - extremes_Y = Y.sel(**{dim: extreme_idx}) - #extremes_Y.coords[dim].values += np.timedelta64(*timedelta) - self._add_timedelta(extremes_Y, dim, timedelta) + extremes_Y = extremes_Y.sel(**{dim: extreme_idx}) + self._add_timedelta([extremes_Y], dim, timedelta) self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim) self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))