From b358896149ad49234c2641a10b25810c4657ee56 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 15 Jul 2021 12:16:40 +0200
Subject: [PATCH] PlotBootstrapSkillScore: add sampling and ahead_dim parameters

---
 mlair/plotting/postprocessing_plotting.py | 71 +++++++++++++----------
 mlair/run_modules/post_processing.py      |  3 +-
 2 files changed, 41 insertions(+), 33 deletions(-)

diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 491aa52e..75249e11 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -608,7 +608,8 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
 
     """
 
-    def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List = None):
+    def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = "", separate_vars: List = None,
+                 sampling: str = "daily", ahead_dim: str = "ahead"):
         """
         Set attributes and create plot.
 
@@ -616,20 +617,23 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         :param plot_folder: path to save the plot (default: current directory)
         :param model_setup: architecture type to specify plot name (default "CNN")
         :param separate_vars: variables to plot separated (default: ['o3'])
+        :param sampling: sampling rate, "hourly" or "daily", or a pair whose second entry is used (default: "daily")
+        :param ahead_dim: name of the ahead dimension (default: "ahead")
         """
         super().__init__(plot_folder, f"skill_score_bootstrap_{model_setup}")
         if separate_vars is None:
             separate_vars = ['o3']
         self._labels = None
         self._x_name = "boot_var"
-        self._data = self._prepare_data(data)
+        self._ahead_dim = ahead_dim
+        self._data = self._prepare_data(data, sampling)
         self._plot()
         self._save()
         self.plot_name += '_separated'
         self._plot(separate_vars=separate_vars)
         self._save(bbox_inches='tight')
 
-    def _prepare_data(self, data: Dict) -> pd.DataFrame:
+    def _prepare_data(self, data: Dict, sampling: str) -> pd.DataFrame:
         """
         Shrink given data, if only scores are relevant.
 
@@ -640,23 +644,33 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
         :return: pre-processed data set
         """
         data = helpers.dict_to_xarray(data, "station").sortby(self._x_name)
-        new_boot_coords = self._return_vars_without_number_tag(data.coords['boot_var'].values, split_by='_', keep=1)
-        data = data.assign_coords({'boot_var': new_boot_coords})
-        self._labels = [str(i) + "d" for i in data.coords["ahead"].values]
+        new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', keep=1)
+        data = data.assign_coords({self._x_name: new_boot_coords})
+        _, sampling_letter = self._get_target_sampling(sampling, 1)
+        # sampling_letter is "d" for daily and "H" for hourly sampling, otherwise empty (entry 1 of a sampling pair
+        # is evaluated); each ahead value gets this letter appended to form its label
+        self._labels = [str(i) + sampling_letter for i in data.coords[self._ahead_dim].values]
         if "station" not in data.dims:
             data = data.expand_dims("station")
         return data.to_dataframe("data").reset_index(level=[0, 1, 2])
 
+    @staticmethod
+    def _get_target_sampling(sampling, pos):
+        """Return sampling (a str is duplicated into a pair) and the unit letter "H"/"d"/"" for its entry `pos`."""
+        pair = sampling if not isinstance(sampling, str) else (sampling, sampling)
+        return pair, {"daily": "d", "hourly": "H"}.get(pair[pos], "")
+
     def _return_vars_without_number_tag(self, values, split_by, keep):
         arr = np.array([v.split(split_by) for v in values])
         num = arr[:, 0]
+        if arr.shape[keep] == 1:  # for keep=1: each value split into one part only, i.e. no number tag to strip
+            return num  # first (and only) part of every value, i.e. the values as they were passed in
         new_val = arr[:, keep]
         if self._all_values_are_equal(num, axis=0):
             return new_val
         else:
             raise NotImplementedError
 
-
     @staticmethod
     def _all_values_are_equal(arr, axis=0):
         if np.all(arr == arr[0], axis=axis):
@@ -681,37 +695,29 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
             self._plot_selected_variables(separate_vars)
 
     def _plot_selected_variables(self, separate_vars: List):
-        # if separate_vars is None:
-        #     separate_vars = ['o3']
         data = self._data
-        self.raise_error_if_separate_vars_do_not_exist(data, separate_vars)
-        all_variables = self._get_unique_values_from_column_of_df(data, 'boot_var')
+        self.raise_error_if_separate_vars_do_not_exist(data, separate_vars, self._x_name)
+        all_variables = self._get_unique_values_from_column_of_df(data, self._x_name)
         # remaining_vars = helpers.list_pop(all_variables, separate_vars) #remove_items
         remaining_vars = helpers.remove_items(all_variables, separate_vars)
-        data_first = self._select_data(df=data, variables=separate_vars, column_name='boot_var')
-        data_second = self._select_data(df=data, variables=remaining_vars, column_name='boot_var')
-
-        fig, ax = plt.subplots(nrows=1, ncols=2,
-                               gridspec_kw={'width_ratios': [len(separate_vars),
-                                                             len(remaining_vars)
-                                                             ]
-                                            }
-                               )
+        data_first = self._select_data(df=data, variables=separate_vars, column_name=self._x_name)
+        data_second = self._select_data(df=data, variables=remaining_vars, column_name=self._x_name)
+
+        fig, ax = plt.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [len(separate_vars),
+                                                                               len(remaining_vars)]})
         if len(separate_vars) > 1:
             first_box_width = .8
         else:
             first_box_width = 2.
 
-        sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_first, ax=ax[0], whis=1., palette="Blues_d",
-                    showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
-                    flierprops={"marker": "."}, width=first_box_width
-                    )
+        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_first, ax=ax[0], whis=1.,
+                    palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                    flierprops={"marker": "."}, width=first_box_width)
         ax[0].set(ylabel=f"skill score", xlabel="")
 
-        sns.boxplot(x=self._x_name, y="data", hue="ahead", data=data_second, ax=ax[1], whis=1., palette="Blues_d",
-                    showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
-                    flierprops={"marker": "."},
-                    )
+        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=data_second, ax=ax[1], whis=1.,
+                    palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"},
+                    flierprops={"marker": "."})
         ax[1].set(ylabel="", xlabel="")
         ax[1].yaxis.tick_right()
         handles, _ = ax[1].get_legend_handles_labels()
@@ -749,6 +755,7 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
 
     @staticmethod
     def _select_data(df: pd.DataFrame, variables: List[str], column_name: str) -> pd.DataFrame:
+        selected_data = None
         for i, variable in enumerate(variables):
             if i == 0:
                 selected_data = df.loc[df[column_name] == variable]
@@ -757,15 +764,15 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
                 selected_data = pd.concat([selected_data, tmp_var], axis=0)
         return selected_data
 
-    def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars):
-        if not self._variables_exist_in_df(df=data, variables=separate_vars):
+    def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars, column_name):
+        if not self._variables_exist_in_df(df=data, variables=separate_vars, column_name=column_name):
             raise ValueError(f"At least one entry of `separate_vars' does not exist in `self.data' ")
 
     @staticmethod
     def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List:
         return list(df[column_name].unique())
 
-    def _variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str = 'boot_var'):
+    def _variables_exist_in_df(self, df: pd.DataFrame, variables: List[str], column_name: str):
         vars_in_df = set(self._get_unique_values_from_column_of_df(df, column_name))
         return set(variables).issubset(vars_in_df)
 
@@ -774,7 +781,7 @@ class PlotBootstrapSkillScore(AbstractPlotClass):
 
         """
         fig, ax = plt.subplots()
-        sns.boxplot(x=self._x_name, y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d",
+        sns.boxplot(x=self._x_name, y="data", hue=self._ahead_dim, data=self._data, ax=ax, whis=1., palette="Blues_d",
                     showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
         ax.axhline(y=0, color="grey", linewidth=.5)
         plt.xticks(rotation=45)
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 0d7bfeb4..f6eec3c8 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -318,7 +318,8 @@ class PostProcessing(RunEnvironment):
         try:
             if (self.bootstrap_skill_scores is not None) and ("PlotBootstrapSkillScore" in plot_list):
                 PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path,
-                                        model_setup=self.forecast_indicator)
+                                        model_setup=self.forecast_indicator, sampling=self._sampling,
+                                        ahead_dim=self.ahead_dim, separate_vars=to_list(self.target_var))
         except Exception as e:
             logging.error(f"Could not create plot PlotBootstrapSkillScore due to the following error: {e}")
 
-- 
GitLab