From 9c5ba0de6a27e8d44baea59d61aab8a4bca2d6a3 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Thu, 27 Jan 2022 11:43:06 +0100
Subject: [PATCH] include feature boots set: grouped by variable across all
 sectors

---
 mlair/data_handler/input_bootstraps.py    | 42 ++++++++++++++++++-----
 mlair/plotting/postprocessing_plotting.py |  7 ++--
 2 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/mlair/data_handler/input_bootstraps.py b/mlair/data_handler/input_bootstraps.py
index 287beea2..3c06f7b6 100644
--- a/mlair/data_handler/input_bootstraps.py
+++ b/mlair/data_handler/input_bootstraps.py
@@ -94,11 +94,6 @@ class BootstrapIteratorSingleInput(BootstrapIterator):
         """Return next element or stop iteration."""
         try:
             _X, _Y, (index, dimension) = self._prepare_data_for_next()
-            # index, dimension = self._collection[self._position]
-            # nboot = self._data.number_of_bootstraps
-            # _X, _Y = self._data.data.get_data(as_numpy=False)
-            # _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X))
-            # _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1)
             single_variable = _X[index].sel({self._dimension: [dimension]})
             bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
             bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
@@ -173,7 +168,7 @@ class BootstrapIteratorBranch(BootstrapIterator):
         return list(range(len(data.get_X(as_numpy=False))))
 
 
-class BootstrapIteratorVariableSets(BootstrapIterator):
+class BootstrapIteratorSets(BootstrapIterator):
     _variable_set_splitters: list = ['Sect', 'SectLeft', 'SectRight',]
 
     def __init__(self, *args, **kwargs):
@@ -192,10 +187,18 @@ class BootstrapIteratorVariableSets(BootstrapIterator):
         except IndexError:
             raise StopIteration()
         _X, _Y = self._to_numpy(_X), self._to_numpy(_Y)
-        # dimension_return_by_seperator = [sec for i, var in enumerate(dimensions) for sec in self._variable_set_splitters if (var.endswith(sec) and i ==0)]
-        # return self._reshape(_X), self._reshape(_Y), (index, dimension_return_by_seperator[0])
         return self._reshape(_X), self._reshape(_Y), (index, dimensions)
 
+    @classmethod
+    def create_collection(cls, data, dim):
+        raise NotImplementedError  # grouping strategy is supplied by each subclass override
+
+
+class BootstrapIteratorSetsSector(BootstrapIteratorSets):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     @classmethod
     def create_collection(cls, data, dim):
         l = []
@@ -211,6 +214,26 @@ class BootstrapIteratorVariableSets(BootstrapIterator):
         return res
 
 
+class BootstrapIteratorSetsVariablesInAllSectors(BootstrapIteratorSets):
+    """Bootstrap each base variable jointly with all of its sector variants."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def create_collection(cls, data, dim):
+        """Return (branch index, [var + suffix, ...]) for each base variable of each input branch."""
+        splitters = tuple(cls._variable_set_splitters)
+        suffixes = [''] + cls._variable_set_splitters
+        res = []
+        # Build groups per branch from that branch's own names: the former comprehension read a stale
+        # loop index (last branch only) and duplicated every group once per branch.
+        for i, x in enumerate(data.get_X(as_numpy=False)):
+            base_vars = sorted(var for var in x.indexes[dim].to_list() if not var.endswith(splitters))
+            res.extend((i, [var + sec for sec in suffixes]) for var in base_vars)
+        return res
+
+
 class ShuffleBootstraps:
 
     @staticmethod
@@ -268,7 +291,8 @@ class Bootstraps(Iterable):
         self.BootstrapIterator = {"singleinput": BootstrapIteratorSingleInput,
                                   "branch": BootstrapIteratorBranch,
                                   "variable": BootstrapIteratorVariable,
-                                  "group_of_variables": BootstrapIteratorVariableSets,
+                                  "group_of_variables_sector": BootstrapIteratorSetsSector,
+                                  "group_of_variables_var_in_sectors": BootstrapIteratorSetsVariablesInAllSectors,
                                   }.get(bootstrap_type, BootstrapIteratorSingleInput)
 
     def __iter__(self):
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index dfffe6f2..3aef2792 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -765,7 +765,8 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
 
     def _set_title(self, model_name, branch=None):
         title_d = {"single input": "Single Inputs", "branch": "Input Branches", "variable": "Variables",
-                   "group_of_variables": "grouped variables"}
+                   "group_of_variables_sector": "grouped variables by sector",
+                   "group_of_variables_var_in_sectors": "grouped variables across sectors"}
         base_title = f"{model_name}\nImportance of {title_d[self._boot_type]}"
 
         additional = []
@@ -845,12 +846,14 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
         num = arr[:, 0]
         if arr.shape[keep] == 1:  # keep dim has only length 1, no number tags required
             return num
-        if self._boot_type == "group_of_variables":
+        if self._boot_type in ["group_of_variables_sector", "group_of_variables_var_in_sectors"]:
             h = []
             for i, subset in enumerate(arr[:, keep]):
                 group_name = self.findstem(ast.literal_eval(subset))
                 if group_name == '':
                     group_name = "Base"
+                if self._boot_type == "group_of_variables_var_in_sectors":
+                    group_name = group_name + "*"
                 h.append(group_name)
             new_val = h
         else:
-- 
GitLab