From 696935148ab2a827b530bc679be8ff67430b6a44 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 27 Oct 2021 09:22:37 +0200
Subject: [PATCH] name of pdf starts now with feature_importance, there is now
 also another separated vars plot for single input feature importance

---
 mlair/plotting/postprocessing_plotting.py | 38 ++++++++++++++++-------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index d9551815..92a327a6 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -622,7 +622,7 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
         :param bootstrap_annotation: additional information to use in the file name (default: None)
         """
         annotation = ["_".join([s for s in ["", bootstrap_type, bootstrap_method] if s is not None])][0]
-        super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}{annotation}")
+        super().__init__(plot_folder, f"feature_importance_{model_setup}{annotation}")
         if separate_vars is None:
             separate_vars = ['o3']
         self._labels = None
@@ -645,6 +645,10 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
                 self._plot(branch=branch)
                 self.plot_name = f"{plot_name}_{branch}"
                 self._save()
+                if len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0:
+                    self.plot_name += '_separated'
+                    self._plot(branch=branch, separate_vars=separate_vars)
+                    self._save(bbox_inches='tight')
         else:
             self._plot()
             self._save()
@@ -673,6 +677,7 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
         """
         station_dim = "station"
         data = helpers.dict_to_xarray(data, station_dim).sortby(self._x_name)
+        data = data.transpose(station_dim, self._ahead_dim, self._boot_dim, self._x_name)
         if self._boot_type == "single input":
             number_tags = self._get_number_tag(data.coords[self._x_name].values, split_by='_')
             new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_',
@@ -741,10 +746,10 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
         if separate_vars is None:
             self._plot_all_variables(branch)
         else:
-            self._plot_selected_variables(separate_vars)
+            self._plot_selected_variables(separate_vars, branch)
 
-    def _plot_selected_variables(self, separate_vars: List):
-        data = self._data
+    def _plot_selected_variables(self, separate_vars: List, branch=None):
+        data = self._data if branch is None else self._data[self._data["branch"] == str(branch)]
         self.raise_error_if_separate_vars_do_not_exist(data, separate_vars, self._x_name)
         all_variables = self._get_unique_values_from_column_of_df(data, self._x_name)
         remaining_vars = helpers.remove_items(all_variables, separate_vars)
@@ -761,14 +766,21 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
 
         sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_first, ax=ax[0], whis=1.5,
                     palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
-                    flierprops={"marker": "."}, width=first_box_width)
+                    showfliers=False, width=first_box_width)
         ax[0].set(ylabel=f"skill score", xlabel="")
+        if self._ylim is not None:
+            _ylim = self._ylim if isinstance(self._ylim, tuple) else self._ylim[0]
+            ax[0].set(ylim=_ylim)
 
         sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_second, ax=ax[1], whis=1.5,
                     palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
-                    flierprops={"marker": "."})
+                    showfliers=False)
         ax[1].set(ylabel="", xlabel="")
         ax[1].yaxis.tick_right()
+        if self._ylim is not None and isinstance(self._ylim, list):
+            _ylim = self._ylim[1]
+            ax[1].set(ylim=_ylim)
+
         handles, _ = ax[1].get_legend_handles_labels()
         for sax in ax:
             matplotlib.pyplot.sca(sax)
@@ -835,17 +847,21 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
         plot_data = self._data if branch is None else self._data[self._data["branch"] == str(branch)]
         if self._boot_type == "branch":
             fig, ax = plt.subplots(figsize=(0.5 + 2 / len(plot_data[self._x_name].unique()) + len(plot_data[self._x_name].unique()),4))
-            sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1., palette="Blues_d",
-                    showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."},
-                    width=0.8)
+            sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.,
+                        palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                        showfliers=False, width=0.8)
         else:
             fig, ax = plt.subplots()
             sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=plot_data, ax=ax, whis=1.5, palette="Blues_d",
-                        showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
+                        showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, showfliers=False)
         ax.axhline(y=0, color="grey", linewidth=.5)
 
         if self._ylim is not None:
-            ax.set(ylim=self._ylim)
+            if isinstance(self._ylim, tuple):
+                _ylim = self._ylim
+            else:
+                _ylim = (min(self._ylim[0][0], self._ylim[1][0]), max(self._ylim[0][1], self._ylim[1][1]))
+            ax.set(ylim=_ylim)
 
         if self._boot_type == "branch":
             plt.xticks()
-- 
GitLab