diff --git a/mlair/data_handler/bootstraps.py b/mlair/data_handler/bootstraps.py index 0ae88599d94ba6405f280df07c9421d9c72a3e6c..e03881484bfc9b8275ede8a4432072c74643994a 100644 --- a/mlair/data_handler/bootstraps.py +++ b/mlair/data_handler/bootstraps.py @@ -144,8 +144,27 @@ class BootstrapIteratorBranch(BootstrapIterator): super().__init__(*args) def __next__(self): - pass - # TODO: implement here: permute entire branch at once + try: + index = self._collection[self._position] + nboot = self._data.number_of_bootstraps + _X, _Y = self._data.data.get_data(as_numpy=False) + _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X)) + _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1) + for dimension in _X[index].coords[self._dimension].values: + single_variable = _X[index].sel({self._dimension: [dimension]}) + bootstrapped_variable = self.apply_bootstrap_method(single_variable.values) + bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords, + dims=single_variable.dims) + _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims) + self._position += 1 + except IndexError: + raise StopIteration() + _X, _Y = self._to_numpy(_X), self._to_numpy(_Y) + return self._reshape(_X), self._reshape(_Y), (None, index) + + @classmethod + def create_collection(cls, data, dim): + return list(range(len(data.get_X(as_numpy=False)))) class ShuffleBootstraps: diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 57b4d6ef7eafa7bb5d82a123cbfbd879fc596027..df8a7d5e88bf3e8cd102a70cb3a2c50d2383c051 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -154,8 +154,6 @@ class PostProcessing(RunEnvironment): went wrong). """ self.bootstrap_skill_scores = {} - bootstrap_type = ["variable", "singleinput"] # Todo: make flexible - bootstrap_method = ["shuffle", "zero_mean"] # Todo: make flexible for boot_type in to_list(bootstrap_type): self.bootstrap_skill_scores[boot_type] = {} for boot_method in to_list(bootstrap_method): @@ -251,7 +249,7 @@ class PostProcessing(RunEnvironment): # calculate skill scores for each variable skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1)) for boot_set in bootstrap_iter: - boot_var = boot_set if isinstance(boot_set, str) else f"{boot_set[0]}_{boot_set[1]}" + boot_var = f"{boot_set[0]}_{boot_set[1]}" if isinstance(boot_set, tuple) else str(boot_set) file_name = os.path.join(forecast_path, f"bootstraps_{station}_{boot_var}_{bootstrap_type}_{bootstrap_method}.nc") with xr.open_dataarray(file_name) as da: