From e7c2c32ba84f58b2b5c58b4eb87c1488bfb38560 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Mon, 24 Jan 2022 15:06:51 +0100
Subject: [PATCH] update boot_feature plots

---
 mlair/data_handler/input_bootstraps.py    | 29 ++---------------------
 mlair/plotting/postprocessing_plotting.py | 27 +++++++++++----------
 2 files changed, 17 insertions(+), 39 deletions(-)

diff --git a/mlair/data_handler/input_bootstraps.py b/mlair/data_handler/input_bootstraps.py
index b434cebd..2c0027f4 100644
--- a/mlair/data_handler/input_bootstraps.py
+++ b/mlair/data_handler/input_bootstraps.py
@@ -178,33 +178,6 @@ class BootstrapIteratorVariableSets(BootstrapIterator):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        #self.variable_set_splitters = ['Sect', 'SectLeft', 'SectRight']
-
-    # def __next__(self):
-    #     try:
-    #         _X, _Y, (index, dimensions) = self._prepare_data_for_next()
-    #         for dimension in dimensions:  # _X[index].coords[self._dimension].values:
-    #             single_variable = _X[index].sel({self._dimension: [dimension]})
-    #             bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
-    #             bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
-    #                                              dims=single_variable.dims)
-    #             _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims)
-    #
-    #         # for dimension in _X[index].coords[self._dimension].values:
-    #         #     single_variable = _X[index].sel({self._dimension: [dimension]})
-    #         #     bootstrapped_variable = self.apply_bootstrap_method(single_variable.values)
-    #         #     bootstrapped_data = xr.DataArray(bootstrapped_variable, coords=single_variable.coords,
-    #         #                                      dims=single_variable.dims)
-    #         #     _X[index] = bootstrapped_data.combine_first(_X[index]).transpose(*_X[index].dims)
-    #         self._position += 1
-    #     except IndexError:
-    #         StopIteration()
-    #     except Exception:
-    #         pass
-    #
-    #     _X, _Y = self._to_numpy(_X), self._to_numpy(_Y)
-    #     return self._reshape(_X), self._reshape(_Y), (index, dimensions)
-    #     # return self._reshape(_X), self._reshape(_Y), (None, index)
 
     def __next__(self):
         try:
@@ -237,6 +210,8 @@ class BootstrapIteratorVariableSets(BootstrapIterator):
         # l[0] = l[0] + ['o3Sect', 'o3SectLeft', 'o3SectRight', 'no2Sect', 'no2SectLeft', 'no2SectRight']
 
         res = [[var for var in l[i] if var.endswith(collection_name)] for collection_name in cls._variable_set_splitters]
+        base_vars = [var for var in l[i] if not var.endswith(tuple(cls._variable_set_splitters))]
+        res.append(base_vars)
         res = [(i, dimensions) for i, _ in enumerate(data.get_X(as_numpy=False)) for dimensions in res]
         return res
         # return list(chain(*res))
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 18c4a9e4..4f5107ac 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -129,7 +129,7 @@ class PlotMonthlySummary(AbstractPlotClass):  # pragma: no cover
         """
         data = self._data.to_dataset(name='values').to_dask_dataframe()
         logging.debug("... start plotting")
-        color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d",
+        color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_r",
                                                                                 self._window_lead_time).as_hex()
         ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1.5, palette=color_palette,
                          flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
@@ -465,7 +465,7 @@ class PlotClimatologicalSkillScore(AbstractPlotClass):  # pragma: no cover
         fig, ax = plt.subplots()
         if not score_only:
             fig.set_size_inches(11.7, 8.27)
-        sns.boxplot(x="terms", y="data", hue="ahead", data=self._data, ax=ax, whis=1.5, palette="Blues_d",
+        sns.boxplot(x="terms", y="data", hue="ahead", data=self._data, ax=ax, whis=1.5, palette="Blues_r",
                     showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
         ax.axhline(y=0, color="grey", linewidth=.5)
         ax.set(ylabel=f"{self._label_add(score_only)}skill score", xlabel="", title="summary of all stations",
@@ -557,7 +557,7 @@ class PlotCompetitiveSkillScore(AbstractPlotClass):  # pragma: no cover
         fig, ax = plt.subplots(figsize=(size, size * 0.8))
         data = self._filter_comparisons(self._data) if single_model_comparison is True else self._data
         order = self._create_pseudo_order(data)
-        sns.boxplot(x="comparison", y="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_d",
+        sns.boxplot(x="comparison", y="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_r",
                     showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
                     order=order)
         ax.axhline(y=0, color="grey", linewidth=.5)
@@ -572,7 +572,7 @@ class PlotCompetitiveSkillScore(AbstractPlotClass):  # pragma: no cover
         fig, ax = plt.subplots()
         data = self._filter_comparisons(self._data) if single_model_comparison is True else self._data
         order = self._create_pseudo_order(data)
-        sns.boxplot(y="comparison", x="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_d",
+        sns.boxplot(y="comparison", x="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_r",
                     showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
                     order=order)
         ax.axvline(x=0, color="grey", linewidth=.5)
@@ -638,7 +638,7 @@ class PlotSectorialSkillScore(AbstractPlotClass):  # pragma: no cover
         size = max([len(np.unique(self._data.sector)), 6])
         fig, ax = plt.subplots(figsize=(size, size * 0.8))
         data = self._data
-        sns.boxplot(x="sector", y="data", hue="ahead", data=data, whis=1, ax=ax, palette="Blues_d",
+        sns.boxplot(x="sector", y="data", hue="ahead", data=data, whis=1, ax=ax, palette="Blues_r",
                     showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
                     )
         ax.axhline(y=0, color="grey", linewidth=.5)
@@ -653,7 +653,7 @@ class PlotSectorialSkillScore(AbstractPlotClass):  # pragma: no cover
         """Plot skill scores of the comparisons, but vertically aligned."""
         fig, ax = plt.subplots()
         data = self._data
-        sns.boxplot(y="sector", x="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_d",
+        sns.boxplot(y="sector", x="data", hue="ahead", data=data, whis=1.5, ax=ax, palette="Blues_r",
                     showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
                     )
         ax.axvline(x=0, color="grey", linewidth=.5)
@@ -846,7 +846,10 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
         if self._boot_type == "group_of_variables":
             h = []
             for i, subset in enumerate(arr[:, keep]):
-                h.append(self.findstem(ast.literal_eval(subset)))
+                group_name = self.findstem(ast.literal_eval(subset))
+                if group_name == '':
+                    group_name = "Base"
+                h.append(group_name)
             new_val = h
         else:
             new_val = arr[:, keep]
@@ -973,7 +976,7 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
             first_box_width = .8
 
         sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_first, ax=ax[0], whis=1.,
-                    palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                    palette="Blues_r", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
                     showfliers=False, width=first_box_width)
         ax[0].set(ylabel=f"skill score", xlabel="")
         if self._ylim is not None:
@@ -981,7 +984,7 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
             ax[0].set(ylim=_ylim)
 
         sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_second, ax=ax[1], whis=1.5,
-                    palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                    palette="Blues_r", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
                     showfliers=False, flierprops={"marker": "."})
 
         ax[1].set(ylabel="", xlabel="")
@@ -1074,11 +1077,11 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
         if self._boot_type == "branch":
             fig, ax = plt.subplots(figsize=(0.5 + 2 / len(plot_data[self._x_name].unique()) + len(plot_data[self._x_name].unique()),4))
             sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.,
-                        palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                        palette="Blues_r", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
                         showfliers=False, width=0.8)
         else:
             fig, ax = plt.subplots()
-            sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.5, palette="Blues_d",
+            sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.5, palette="Blues_r",
                         showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, showfliers=False)
         ax.axhline(y=0, color="grey", linewidth=.5)
         #<<<<<<< HEAD
@@ -1210,7 +1213,7 @@ class PlotTimeSeries:  # pragma: no cover
         return f, ax[:, 0], factor
 
     def _plot_ahead(self, ax, data):
-        color = sns.color_palette("Blues_d", self._window_lead_time).as_hex()
+        color = sns.color_palette("Blues_r", self._window_lead_time).as_hex()
         for ahead in data.coords[self._ahead_dim].values:
             plot_data = data.sel({"type": self._model_name, self._ahead_dim: ahead}).drop(["type", self._ahead_dim]).squeeze().shift(index=ahead)
             label = f"{ahead}{self._sampling}"
-- 
GitLab