From 9c5ba0de6a27e8d44baea59d61aab8a4bca2d6a3 Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Thu, 27 Jan 2022 11:43:06 +0100 Subject: [PATCH] include feature boots set: grouped by variable across all sectors --- mlair/data_handler/input_bootstraps.py | 42 ++++++++++++++++++----- mlair/plotting/postprocessing_plotting.py | 7 ++-- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/mlair/data_handler/input_bootstraps.py b/mlair/data_handler/input_bootstraps.py index 287beea2..3c06f7b6 100644 --- a/mlair/data_handler/input_bootstraps.py +++ b/mlair/data_handler/input_bootstraps.py @@ -94,11 +94,6 @@ class BootstrapIteratorSingleInput(BootstrapIterator): """Return next element or stop iteration.""" try: _X, _Y, (index, dimension) = self._prepare_data_for_next() - # index, dimension = self._collection[self._position] - # nboot = self._data.number_of_bootstraps - # _X, _Y = self._data.data.get_data(as_numpy=False) - # _X = list(map(lambda x: x.expand_dims({self.boot_dim: range(nboot)}, axis=-1), _X)) - # _Y = _Y.expand_dims({self.boot_dim: range(nboot)}, axis=-1) single_variable = _X[index].sel({self._dimension: [dimension]}) bootstrapped_variable = self.apply_bootstrap_method(single_variable.values) bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords, @@ -173,7 +168,7 @@ class BootstrapIteratorBranch(BootstrapIterator): return list(range(len(data.get_X(as_numpy=False)))) -class BootstrapIteratorVariableSets(BootstrapIterator): +class BootstrapIteratorSets(BootstrapIterator): _variable_set_splitters: list = ['Sect', 'SectLeft', 'SectRight',] def __init__(self, *args, **kwargs): @@ -192,10 +187,18 @@ class BootstrapIteratorVariableSets(BootstrapIterator): except IndexError: raise StopIteration() _X, _Y = self._to_numpy(_X), self._to_numpy(_Y) - # dimension_return_by_seperator = [sec for i, var in enumerate(dimensions) for sec in self._variable_set_splitters if (var.endswith(sec) and i ==0)] - # return self._reshape(_X), self._reshape(_Y), (index, dimension_return_by_seperator[0]) return self._reshape(_X), self._reshape(_Y), (index, dimensions) + @classmethod + def create_collection(cls, data, dim): + raise NotImplementedError + + +class BootstrapIteratorSetsSector(BootstrapIteratorSets): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + @classmethod def create_collection(cls, data, dim): l = [] @@ -211,6 +214,26 @@ class BootstrapIteratorVariableSets(BootstrapIterator): return res +class BootstrapIteratorSetsVariablesInAllSectors(BootstrapIteratorSets): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def create_collection(cls, data, dim): + l = [] + for i, x in enumerate(data.get_X(as_numpy=False)): + l.append(x.indexes[dim].to_list()) + + base_vars = sorted([var for var in l[i] if not var.endswith(tuple(cls._variable_set_splitters)) for i, _ in + enumerate(data.get_X(as_numpy=False))]) + + res = [[(var + sec) for sec in [''] + cls._variable_set_splitters] for var in base_vars] + + res = [(i, dimensions) for i, _ in enumerate(data.get_X(as_numpy=False)) for dimensions in res] + return res + + class ShuffleBootstraps: @staticmethod @@ -268,7 +291,8 @@ class Bootstraps(Iterable): self.BootstrapIterator = {"singleinput": BootstrapIteratorSingleInput, "branch": BootstrapIteratorBranch, "variable": BootstrapIteratorVariable, - "group_of_variables": BootstrapIteratorVariableSets, + "group_of_variables_sector": BootstrapIteratorSetsSector, + "group_of_variables_var_in_sectors": BootstrapIteratorSetsVariablesInAllSectors, }.get(bootstrap_type, BootstrapIteratorSingleInput) def __iter__(self): diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index dfffe6f2..3aef2792 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -765,7 +765,8 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover def _set_title(self, model_name, branch=None): title_d = {"single input": "Single Inputs", "branch": "Input Branches", "variable": "Variables", - "group_of_variables": "grouped variables"} + "group_of_variables_sector": "grouped variables by sector", + "group_of_variables_var_in_sectors": "grouped variables across sectors"} base_title = f"{model_name}\nImportance of {title_d[self._boot_type]}" additional = [] @@ -845,12 +846,14 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover num = arr[:, 0] if arr.shape[keep] == 1: # keep dim has only length 1, no number tags required return num - if self._boot_type == "group_of_variables": + if self._boot_type in ["group_of_variables_sector", "group_of_variables_var_in_sectors"]: h = [] for i, subset in enumerate(arr[:, keep]): group_name = self.findstem(ast.literal_eval(subset)) if group_name == '': group_name = "Base" + if self._boot_type == "group_of_variables_var_in_sectors": + group_name = group_name + "*" h.append(group_name) new_val = h else: -- GitLab