From 54c5f1c1bc12c3593335e5674fbebc295bcd806b Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Fri, 16 Jul 2021 14:46:50 +0200
Subject: [PATCH] added new bootstrap method "zero mean" and type "single
 input" and "variable"

---
 mlair/data_handler/bootstraps.py           | 172 ++++++++++++++++-----
 mlair/plotting/postprocessing_plotting.py  |  87 ++++++++---
 mlair/run_modules/post_processing.py       | 127 +++++++++------
 test/test_data_handler/old_t_bootstraps.py |   2 +-
 4 files changed, 285 insertions(+), 103 deletions(-)

diff --git a/mlair/data_handler/bootstraps.py b/mlair/data_handler/bootstraps.py
index 68a4bbc4..0ae88599 100644
--- a/mlair/data_handler/bootstraps.py
+++ b/mlair/data_handler/bootstraps.py
@@ -15,69 +15,156 @@ __date__ = '2020-02-07'
 import os
 from collections import Iterator, Iterable
 from itertools import chain
+from typing import Union, List
 
 import numpy as np
 import xarray as xr
 
 from mlair.data_handler.abstract_data_handler import AbstractDataHandler
+from mlair.helpers.helpers import to_list
 
 
 class BootstrapIterator(Iterator):
 
     _position: int = None
 
-    def __init__(self, data: "BootStraps"):
+    def __init__(self, data: "BootStraps", method):
         assert isinstance(data, BootStraps)
         self._data = data
         self._dimension = data.bootstrap_dimension
-        self._collection = self._data.bootstraps()
+        self.boot_dim = "boots"
+        self._method = method
+        self._collection = self.create_collection(self._data.data, self._dimension)
         self._position = 0
 
+    def __next__(self):
+        """Return next element or stop iteration."""
+        raise NotImplementedError
+
+    @classmethod
+    def create_collection(cls, data, dim):
+        raise NotImplementedError
+
+    def _reshape(self, d):
+        if isinstance(d, list):
+            return list(map(lambda x: self._reshape(x), d))
+            # return list(map(lambda x: np.rollaxis(x, -1, 0).reshape(x.shape[0] * x.shape[-1], *x.shape[1:-1]), d))
+        else:
+            shape = d.shape
+            return np.rollaxis(d, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1])
+
+    def _to_numpy(self, d):
+        if isinstance(d, list):
+            return list(map(lambda x: self._to_numpy(x), d))
+        else:
+            return d.values
+
+    def apply_bootstrap_method(self, data: np.ndarray) -> Union[np.ndarray, List[np.ndarray]]:
+        """
+        Apply predefined bootstrap method from given data.
+
+        :param data: data to apply bootstrap method on
+        :return: processed data as numpy array
+        """
+        if isinstance(data, list):
+            return list(map(lambda x: self.apply_bootstrap_method(x.values), data))
+        else:
+            return self._method.apply(data)
+
+
+class BootstrapIteratorSingleInput(BootstrapIterator):
+    _position: int = None
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
     def __next__(self):
         """Return next element or stop iteration."""
         try:
             index, dimension = self._collection[self._position]
             nboot = self._data.number_of_bootstraps
             _X, _Y = self._data.data.get_data(as_numpy=False)
-            _X = list(map(lambda x: x.expand_dims({'boots': range(nboot)}, axis=-1), _X))
-            _Y = _Y.expand_dims({"boots": range(nboot)}, axis=-1)
+            _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
+            _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
             single_variable = _X[index].sel({self._dimension: [dimension]})
-            shuffled_variable = self.shuffle(single_variable.values)
-            shuffled_data = xr.DataArray(shuffled_variable, coords=single_variable.coords, dims=single_variable.dims)
-            _X[index] = shuffled_data.combine_first(_X[index]).reindex_like(_X[index])
+            bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
+            bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
+                                             dims=single_variable.dims)
+            _X[index] = bootstrapped_data.combine_first(_X[index]).reindex_like(_X[index])
             self._position += 1
         except IndexError:
             raise StopIteration()
         _X, _Y = self._to_numpy(_X), self._to_numpy(_Y)
         return self._reshape(_X), self._reshape(_Y), (index, dimension)
 
-    @staticmethod
-    def _reshape(d):
-        if isinstance(d, list):
-            return list(map(lambda x: np.rollaxis(x, -1, 0).reshape(x.shape[0] * x.shape[-1], *x.shape[1:-1]), d))
-        else:
-            shape = d.shape
-            return np.rollaxis(d, -1, 0).reshape(shape[0] * shape[-1], *shape[1:-1])
+    @classmethod
+    def create_collection(cls, data, dim):
+        l = []
+        for i, x in enumerate(data.get_X(as_numpy=False)):
+            l.append(list(map(lambda y: (i, y), x.indexes[dim])))
+        return list(chain(*l))
 
-    @staticmethod
-    def _to_numpy(d):
-        if isinstance(d, list):
-            return list(map(lambda x: x.values, d))
-        else:
-            return d.values
 
-    @staticmethod
-    def shuffle(data: np.ndarray) -> np.ndarray:
-        """
-        Shuffle randomly from given data (draw elements with replacement).
+class BootstrapIteratorVariable(BootstrapIterator):
 
-        :param data: data to shuffle
-        :return: shuffled data as numpy array
-        """
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def __next__(self):
+        """Return next element or stop iteration."""
+        try:
+            dimension = self._collection[self._position]
+            nboot = self._data.number_of_bootstraps
+            _X, _Y = self._data.data.get_data(as_numpy=False)
+            _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
+            _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
+            for index in range(len(_X)):
+                single_variable = _X[index].sel({self._dimension: [dimension]})
+                bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
+                bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
+                                                 dims=single_variable.dims)
+                _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims)
+            self._position += 1
+        except IndexError:
+            raise StopIteration()
+        _X, _Y = self._to_numpy(_X), self._to_numpy(_Y)
+        return self._reshape(_X), self._reshape(_Y), (None, dimension)
+
+    @classmethod
+    def create_collection(cls, data, dim):
+        l = set()
+        for i, x in enumerate(data.get_X(as_numpy=False)):
+            l.update(x.indexes[dim].to_list())
+        return to_list(l)
+
+
+class BootstrapIteratorBranch(BootstrapIterator):
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def __next__(self):
+        pass
+    # TODO: implement here: permute entire branch at once
+
+
+class ShuffleBootstraps:
+
+    @staticmethod
+    def apply(data):
         size = data.shape
         return np.random.choice(data.reshape(-1, ), size=size)
 
 
+class MeanBootstraps:
+
+    def __init__(self, mean):
+        self._mean = mean
+
+    def apply(self, data):
+        return np.ones_like(data) * self._mean
+
+
 class BootStraps(Iterable):
     """
     Main class to perform bootstrap operations.
@@ -89,10 +176,19 @@ class BootStraps(Iterable):
     this variable). The tuple is interesting if X consists on mutliple input streams X_i (e.g. two or more stations)
     because it shows which variable of which input X_i has been bootstrapped. All bootstrap combinations can be
     retrieved by calling the .bootstraps() method. Further more, by calling the .get_orig_prediction() this class
-    imitates according to the set number of bootstraps the original prediction
+    imitates according to the set number of bootstraps the original prediction.
+
+    As bootstrap method, this class can currently make use of the ShuffleBoostraps class that uses drawing with
+    replacement to destroy the variables information by keeping its statistical properties. Use `bootstrap="shuffle"` to
+    call this method. Another method is the zero mean bootstrapping triggered by `bootstrap="zero_mean"` and performed
+    by the MeanBootstraps class. This method destroy the variable's information by a mode collapse to constant value of
+    zero. In case, the variable is normalized with a zero mean, this is equivalent to a mode collapse to the variable's
+    mean value. Statistics in general are not conserved in this case, but the mean value of course. A custom mean value
+    for bootstrapping is currently not supported.
     """
+
     def __init__(self, data: AbstractDataHandler, number_of_bootstraps: int = 10,
-                 bootstrap_dimension: str = "variables"):
+                 bootstrap_dimension: str = "variables", bootstrap_type="singleinput", bootstrap_method="shuffle"):
         """
         Create iterable class to be ready to iter.
 
@@ -100,20 +196,24 @@ class BootStraps(Iterable):
         :param number_of_bootstraps: the number of bootstrap realisations
         """
         self.data = data
-        self.number_of_bootstraps = number_of_bootstraps
+        self.number_of_bootstraps = number_of_bootstraps if bootstrap_method == "shuffle" else 1
         self.bootstrap_dimension = bootstrap_dimension
+        self.bootstrap_method = {"shuffle": ShuffleBootstraps(),
+                                 "zero_mean": MeanBootstraps(mean=0)}.get(
+            bootstrap_method)  # todo adjust number of bootstraps if mean bootstrapping
+        self.BootstrapIterator = {"singleinput": BootstrapIteratorSingleInput,
+                                  "branch": BootstrapIteratorBranch,
+                                  "variable": BootstrapIteratorVariable}.get(bootstrap_type,
+                                                                             BootstrapIteratorSingleInput)
 
     def __iter__(self):
-        return BootstrapIterator(self)
+        return self.BootstrapIterator(self, self.bootstrap_method)
 
     def __len__(self):
-        return len(self.bootstraps())
+        return len(self.BootstrapIterator.create_collection(self.data, self.bootstrap_dimension))
 
     def bootstraps(self):
-        l = []
-        for i, x in enumerate(self.data.get_X(as_numpy=False)):
-            l.append(list(map(lambda y: (i, y), x.indexes['variables'])))
-        return list(chain(*l))
+        return self.BootstrapIterator.create_collection(self.data, self.bootstrap_dimension)
 
     def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN") -> np.ndarray:
         """
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 75249e11..e5080f6e 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -609,7 +609,8 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
     """
 
     def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List = None,
-                 sampling: str = "daily", ahead_dim: str = "ahead"):
+                 sampling: str = "daily", ahead_dim: str = "ahead", bootstrap_type: str = None,
+                 bootstrap_method: str = None):
         """
         Set attributes and create plot.
 
@@ -619,19 +620,41 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         :param separate_vars: variables to plot separated (default: ['o3'])
         :param sampling: type of sampling rate, should be either hourly or daily (default: "daily")
         :param ahead_dim: name of the ahead dimensions (default: "ahead")
+        :param bootstrap_annotation: additional information to use in the file name (default: None)
         """
-        super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}")
+        annotation = ["_".join([s for s in ["", bootstrap_type, bootstrap_method] if s is not None])][0]
+        super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}{annotation}")
         if separate_vars is None:
             separate_vars = ['o3']
         self._labels = None
         self._x_name = "boot_var"
         self._ahead_dim = ahead_dim
+        self._boot_type = self._set_bootstrap_type(bootstrap_type)
+        self._boot_method = self._set_bootstrap_method(bootstrap_method)
+
+        self._title = f"Bootstrap analysis ({self._boot_method}, {self._boot_type})"
         self._data = self._prepare_data(data, sampling)
-        self._plot()
-        self._save()
-        self.plot_name += '_separated'
-        self._plot(separate_vars=separate_vars)
-        self._save(bbox_inches='tight')
+        if "branch" in self._data.columns:
+            plot_name = self.plot_name
+            for branch in self._data["branch"].unique():
+                self._title = f"Bootstrap analysis ({self._boot_method}, {self._boot_type}, {branch})"
+                self._plot(branch=branch)
+                self.plot_name = f"{plot_name}_{branch}"
+                self._save()
+        else:
+            self._plot()
+            self._save()
+            self.plot_name += '_separated'
+            self._plot(separate_vars=separate_vars)
+            self._save(bbox_inches='tight')
+
+    @staticmethod
+    def _set_bootstrap_type(boot_type):
+        return {"singleinput": "single input"}.get(boot_type, boot_type)
+
+    @staticmethod
+    def _set_bootstrap_method(boot_method):
+        return {"zero_mean": "zero mean", "shuffle": "shuffled"}.get(boot_method, boot_method)
 
     def _prepare_data(self, data: Dict, sampling: str) -> pd.DataFrame:
         """
@@ -643,16 +666,28 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         :param data: dictionary with station names as keys and 2D xarrays as values
         :return: pre-processed data set
         """
-        data = helpers.dict_to_xarray(data, "station").sortby(self._x_name)
-        new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', keep=1)
-        data = data.assign_coords({self._x_name: new_boot_coords})
+        station_dim = "station"
+        data = helpers.dict_to_xarray(data, station_dim).sortby(self._x_name)
+        if self._boot_type == "single input":
+            number_tags = self._get_number_tag(data.coords[self._x_name].values, split_by='_')
+            new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_',
+                                                                   keep=1, as_unique=True)
+            values = data.values.reshape((data.shape[0], len(new_boot_coords), len(number_tags), data.shape[-1]))
+            data = xr.DataArray(values, coords={station_dim: data.coords["station"], self._x_name: new_boot_coords,
+                                                "branch": number_tags, self._ahead_dim: data.coords[self._ahead_dim]},
+                                dims=[station_dim, self._x_name, "branch", self._ahead_dim])
+        else:
+            try:
+                new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_',
+                                                                       keep=1)
+                data = data.assign_coords({self._x_name: new_boot_coords})
+            except NotImplementedError:
+                pass
         _, sampling_letter = self._get_target_sampling(sampling, 1)
-        # sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
-        # sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[1], "")
         self._labels = [str(i) + sampling_letter for i in data.coords[self._ahead_dim].values]
-        if "station" not in data.dims:
-            data = data.expand_dims("station")
-        return data.to_dataframe("data").reset_index(level=[0, 1, 2])
+        if station_dim not in data.dims:
+            data = data.expand_dims(station_dim)
+        return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist())
 
     @staticmethod
     def _get_target_sampling(sampling, pos):
@@ -660,7 +695,7 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
         return sampling, sampling_letter
 
-    def _return_vars_without_number_tag(self, values, split_by, keep):
+    def _return_vars_without_number_tag(self, values, split_by, keep, as_unique=False):
         arr = np.array([v.split(split_by) for v in values])
         num = arr[:, 0]
         if arr.shape[keep] == 1:  # keep dim has only length 1, no number tags required
@@ -668,9 +703,17 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         new_val = arr[:, keep]
         if self._all_values_are_equal(num, axis=0):
             return new_val
+        elif as_unique is True:
+            return np.unique(new_val)
         else:
             raise NotImplementedError
 
+    @staticmethod
+    def _get_number_tag(values, split_by):
+        arr = np.array([v.split(split_by) for v in values])
+        num = arr[:, 0]
+        return np.unique(num).tolist()
+
     @staticmethod
     def _all_values_are_equal(arr, axis=0):
         if np.all(arr == arr[0], axis=axis):
@@ -687,10 +730,10 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         """
         return "" if score_only else "terms and "
 
-    def _plot(self, separate_vars=None):
+    def _plot(self, branch=None, separate_vars=None):
         """Plot climatological skill score."""
         if separate_vars is None:
-            self._plot_all_variables()
+            self._plot_all_variables(branch)
         else:
             self._plot_selected_variables(separate_vars)
 
@@ -752,6 +795,7 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
 
         align_yaxis(ax[0], ax[1])
         align_yaxis(ax[0], ax[1])
+        plt.title(self._title)
 
     @staticmethod
     def _select_data(df: pd.DataFrame, variables: List[str], column_name: str) -> pd.DataFrame:
@@ -776,16 +820,17 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         vars_in_df = set(self._get_unique_values_from_column_of_df(df, column_name))
         return set(variables).issubset(vars_in_df)
 
-    def _plot_all_variables(self):
+    def _plot_all_variables(self, branch=None):
         """
 
         """
         fig, ax = plt.subplots()
-        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=self._data, ax=ax, whis=1., palette="Blues_d",
+        plot_data = self._data if branch is None else self._data[self._data["branch"] == str(branch)]
+        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1., palette="Blues_d",
                     showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
         ax.axhline(y=0, color="grey", linewidth=.5)
         plt.xticks(rotation=45)
-        ax.set(ylabel=f"skill score", xlabel="", title="summary of all stations")
+        ax.set(ylabel=f"skill score", xlabel="", title=self._title)
         handles, _ = ax.get_legend_handles_labels()
         ax.legend(handles, self._labels)
         plt.tight_layout()
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index f6eec3c8..0c530400 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -103,7 +103,7 @@ class PostProcessing(RunEnvironment):
         if self.data_store.get("evaluate_bootstraps", "postprocessing"):
             with TimeTracking(name="calculate bootstraps"):
                 create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing")
-                self.bootstrap_postprocessing(create_new_bootstraps)
+                self.bootstrap_postprocessing(create_new_bootstraps)  # todo: make flexible and add boot method and type
 
         # skill scores and error metrics
         with TimeTracking(name="calculate skill scores"):
@@ -136,7 +136,8 @@ class PostProcessing(RunEnvironment):
                 continue
         return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None
 
-    def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0) -> None:
+    def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type=None,
+                                 bootstrap_method=None) -> None:
         """
         Calculate skill scores of bootstrapped data.
 
@@ -149,18 +150,28 @@ class PostProcessing(RunEnvironment):
         :param _iter: internal counter to reduce unnecessary recursive calls (maximum number is 2, otherwise something
             went wrong).
         """
-        try:
-            if create_new_bootstraps:
-                self.create_bootstrap_forecast()
-            self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
-        except FileNotFoundError:
-            if _iter != 0:
-                raise RuntimeError("bootstrap_postprocessing is called for the 2nd time. This means, that calling"
-                                   "manually the reason for the failure.")
-            logging.info("Couldn't load all files, restart bootstrap postprocessing with create_new_bootstraps=True.")
-            self.bootstrap_postprocessing(True, _iter=1)
-
-    def create_bootstrap_forecast(self) -> None:
+        self.bootstrap_skill_scores = {}
+        bootstrap_type = ["variable", "singleinput"]  # Todo: make flexible
+        bootstrap_method = ["shuffle", "zero_mean"]  # Todo: make flexible
+        for boot_type in to_list(bootstrap_type):
+            self.bootstrap_skill_scores[boot_type] = {}
+            for boot_method in to_list(bootstrap_method):
+                try:
+                    if create_new_bootstraps:
+                        self.create_bootstrap_forecast(bootstrap_type=boot_type, bootstrap_method=boot_method)
+                    boot_skill_score = self.calculate_bootstrap_skill_scores(bootstrap_type=boot_type,
+                                                                             bootstrap_method=boot_method)
+                    self.bootstrap_skill_scores[boot_type][boot_method] = boot_skill_score
+                except FileNotFoundError:
+                    if _iter != 0:
+                        raise RuntimeError(f"bootstrap_postprocessing ({boot_type}, {boot_type}) was called for the 2nd"
+                                           f" time. This means, that something internally goes wrong. Please check for "
+                                           f"possible errors")
+                    logging.info(f"Could not load all files for bootstrapping ({boot_type}, {boot_type}), restart "
+                                 f"bootstrap postprocessing with create_new_bootstraps=True.")
+                    self.bootstrap_postprocessing(True, _iter=1, bootstrap_type=boot_type, bootstrap_method=boot_method)
+
+    def create_bootstrap_forecast(self, bootstrap_type, bootstrap_method) -> None:
         """
         Create bootstrapped predictions for all stations and variables.
 
@@ -168,16 +179,16 @@ class PostProcessing(RunEnvironment):
         `bootstraps_labels_{station}.nc`.
         """
         # forecast
-        with TimeTracking(name=inspect.stack()[0].function):
+        with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"):
             # extract all requirements from data store
-            bootstrap_path = self.data_store.get("bootstrap_path")
             forecast_path = self.data_store.get("forecast_path")
             number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
             dims = ["index", self.ahead_dim, "type"]
             for station in self.test_data:
-                logging.info(str(station))
+                # logging.info(str(station))
                 X, Y = None, None
-                bootstraps = BootStraps(station, number_of_bootstraps)
+                bootstraps = BootStraps(station, number_of_bootstraps, bootstrap_type=bootstrap_type,
+                                        bootstrap_method=bootstrap_method)
                 for boot in bootstraps:
                     X, Y, (index, dimension) = boot
                     # make bootstrap predictions
@@ -188,18 +199,19 @@ class PostProcessing(RunEnvironment):
                     bootstrap_predictions = np.expand_dims(bootstrap_predictions, axis=-1)
                     shape = bootstrap_predictions.shape
                     coords = (range(shape[0]), range(1, shape[1] + 1))
-                    var = f"{index}_{dimension}"
+                    var = f"{index}_{dimension}" if index is not None else str(dimension)
                     tmp = xr.DataArray(bootstrap_predictions, coords=(*coords, [var]), dims=dims)
-                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{var}.nc")
+                    file_name = os.path.join(forecast_path,
+                                             f"bootstraps_{station}_{var}_{bootstrap_type}_{bootstrap_method}.nc")
                     tmp.to_netcdf(file_name)
                 else:
                     # store also true labels for each station
                     labels = np.expand_dims(Y, axis=-1)
-                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_labels.nc")
+                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc")
                     labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=dims)
                     labels.to_netcdf(file_name)
 
-    def calculate_bootstrap_skill_scores(self) -> Dict[str, xr.DataArray]:
+    def calculate_bootstrap_skill_scores(self, bootstrap_type, bootstrap_method) -> Dict[str, xr.DataArray]:
         """
         Calculate skill score of bootstrapped variables.
 
@@ -209,53 +221,67 @@ class PostProcessing(RunEnvironment):
 
         :return: The result dictionary with station-wise skill scores
         """
-        with TimeTracking(name=inspect.stack()[0].function):
+        with TimeTracking(name=f"{inspect.stack()[0].function} ({bootstrap_type}, {bootstrap_method})"):
             # extract all requirements from data store
-            bootstrap_path = self.data_store.get("bootstrap_path")
             forecast_path = self.data_store.get("forecast_path")
             number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
             forecast_file = f"forecasts_norm_%s_test.nc"
-            bootstraps = BootStraps(self.test_data[0], number_of_bootstraps).bootstraps()
+
+            bootstraps = BootStraps(self.test_data[0], number_of_bootstraps, bootstrap_type=bootstrap_type,
+                                    bootstrap_method=bootstrap_method)
+            number_of_bootstraps = bootstraps.number_of_bootstraps
+            bootstrap_iter = bootstraps.bootstraps()
             skill_scores = statistics.SkillScores(None)
             score = {}
             for station in self.test_data:
-                logging.info(station)
-
                 # get station labels
-                file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_labels.nc")
-                labels = xr.open_dataarray(file_name)
+                file_name = os.path.join(forecast_path, f"bootstraps_{str(station)}_{bootstrap_method}_labels.nc")
+                with xr.open_dataarray(file_name) as da:
+                    labels = da.load()
                 shape = labels.shape
 
                 # get original forecasts
                 orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), number_of_bootstraps)
                 orig = orig.reshape(shape)
                 coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"])
-                orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"])
+                orig = xr.DataArray(orig, coords=coords, dims=["index", self.ahead_dim, "type"])
 
                 # calculate skill scores for each variable
                 skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1))
-                for boot_set in bootstraps:
-                    boot_var = f"{boot_set[0]}_{boot_set[1]}"
-                    file_name = os.path.join(forecast_path, f"bootstraps_{station}_{boot_var}.nc")
-                    boot_data = xr.open_dataarray(file_name)
+                for boot_set in bootstrap_iter:
+                    boot_var = boot_set if isinstance(boot_set, str) else f"{boot_set[0]}_{boot_set[1]}"
+                    file_name = os.path.join(forecast_path,
+                                             f"bootstraps_{station}_{boot_var}_{bootstrap_type}_{bootstrap_method}.nc")
+                    # boot_data = xr.open_dataarray(file_name)
+                    with xr.open_dataarray(file_name) as da:
+                        boot_data = da.load()
                     boot_data = boot_data.combine_first(labels).combine_first(orig)
                     boot_scores = []
                     for ahead in range(1, self.window_lead_time + 1):
-                        data = boot_data.sel(ahead=ahead)
+                        data = boot_data.sel({self.ahead_dim: ahead})
                         boot_scores.append(
                             skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name="orig"))
                     skill.loc[boot_var] = np.array(boot_scores)
 
                 # collect all results in single dictionary
-                score[str(station)] = xr.DataArray(skill, dims=["boot_var", "ahead"])
+                score[str(station)] = xr.DataArray(skill, dims=["boot_var", self.ahead_dim])
             return score
 
     def get_orig_prediction(self, path, file_name, number_of_bootstraps, prediction_name=None):
         if prediction_name is None:
             prediction_name = self.forecast_indicator
         file = os.path.join(path, file_name)
-        prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
-        vals = np.tile(prediction.data, (number_of_bootstraps, 1))
+        # prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
+        with xr.open_dataarray(file) as da:
+            prediction = da.load().sel(type=prediction_name).squeeze()
+        return self.repeat_data(prediction, number_of_bootstraps)
+        # vals = np.tile(prediction.data, (number_of_bootstraps, 1))
+        # return vals[~np.isnan(vals).any(axis=1), :]
+
+    def repeat_data(self, data, number_of_repetition):
+        if isinstance(data, xr.DataArray):
+            data = data.data
+        vals = np.tile(data, (number_of_repetition, 1))
         return vals[~np.isnan(vals).any(axis=1), :]
 
     def _get_model_name(self):
@@ -317,9 +343,15 @@ class PostProcessing(RunEnvironment):
 
         try:
             if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list):
-                PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path,
-                                        model_setup=self.forecast_indicator, sampling=self._sampling,
-                                        ahead_dim=self.ahead_dim, separate_vars=to_list(self.target_var))
+                for boot_type, boot_data in self.bootstrap_skill_scores.items():
+                    for boot_method, boot_skill_score in boot_data.items():
+                        PlotBootstrapSkillScore(boot_skill_score, plot_folder=self.plot_path,
+                                                model_setup=self.forecast_indicator, sampling=self._sampling,
+                                                ahead_dim=self.ahead_dim, separate_vars=to_list(self.target_var),
+                                                bootstrap_type=boot_type, bootstrap_method=boot_method)
+                # PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path,
+                #                         model_setup=self.forecast_indicator, sampling=self._sampling,
+                #                         ahead_dim=self.ahead_dim, separate_vars=to_list(self.target_var))
         except Exception as e:
             logging.error(f"Could not create plot PlotBootstrapSkillScore due to the following error: {e}")
 
@@ -496,8 +528,9 @@ class PostProcessing(RunEnvironment):
         """
         path = os.path.join(self.competitor_path, competitor_name)
         file = os.path.join(path, f"forecasts_{station_name}_test.nc")
-        data = xr.open_dataarray(file)
-        # data = data.expand_dims(Stations=[station_name])  # ToDo: remove line
+        with xr.open_dataarray(file) as da:
+            data = da.load()
+        # data = xr.open_dataarray(file)
         forecast = data.sel(type=[self.forecast_indicator])
         forecast.coords["type"] = [competitor_name]
         return forecast
@@ -653,7 +686,9 @@ class PostProcessing(RunEnvironment):
         """
         try:
             file = os.path.join(path, f"forecasts_{str(station)}_train_val.nc")
-            return xr.open_dataarray(file)
+            with xr.open_dataarray(file) as da:
+                return da.load()
+            # return xr.open_dataarray(file)
         except (IndexError, KeyError, FileNotFoundError):
             return None
 
@@ -668,7 +703,9 @@ class PostProcessing(RunEnvironment):
         """
         try:
             file = os.path.join(path, f"forecasts_{str(station)}_test.nc")
-            return xr.open_dataarray(file)
+            with xr.open_dataarray(file) as da:
+                return da.load()
+            # return xr.open_dataarray(file)
         except (IndexError, KeyError, FileNotFoundError):
             return None
 
diff --git a/test/test_data_handler/old_t_bootstraps.py b/test/test_data_handler/old_t_bootstraps.py
index 9616ed3f..21c18c6c 100644
--- a/test/test_data_handler/old_t_bootstraps.py
+++ b/test/test_data_handler/old_t_bootstraps.py
@@ -160,7 +160,7 @@ class TestCreateShuffledData:
 
     def test_shuffle(self, shuffled_data_no_creation):
         dummy = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
-        res = shuffled_data_no_creation.shuffle(dummy, chunks=(2, 3)).compute()
+        res = shuffled_data_no_creation.apply_bootstrap_method(dummy, chunks=(2, 3)).compute()
         assert res.shape == dummy.shape
         assert dummy.max() >= res.max()
         assert dummy.min() <= res.min()
-- 
GitLab