From be3075e93416be180d7797ce8897540eb5e7d2fd Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Thu, 10 Feb 2022 11:51:32 +0100
Subject: [PATCH] plot sectorial skill scores for all competitors

---
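Reviewer note: skill_score_based_on_mse now accepts a single reference model
name or a list of names. A minimal sketch of the intended semantics, using the
dimension names from the function defaults ("index", "type") and toy random
data (the competitor name "nn" is only illustrative):

    import numpy as np
    import xarray as xr

    # toy forecasts: 5 time steps, one "type" entry per competitor
    data = xr.DataArray(np.random.rand(5, 4), dims=("index", "type"),
                        coords={"type": ["obs", "nn", "ols", "persi"]})
    obs = data.sel(type="obs", drop=True)
    pred = data.sel(type="nn", drop=True)
    # as in the patch: stack all requested references along the competitor dim
    ref = xr.concat([data.sel(type=r) for r in ["ols", "persi"]], dim="type")
    # broadcasting against ref yields one skill score per reference model
    ss = 1 - ((obs - pred) ** 2).mean("index") / ((obs - ref) ** 2).mean("index")
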
 mlair/helpers/statistics.py               |  5 ++++-
 mlair/plotting/postprocessing_plotting.py | 37 ++++++++++++++---------
 mlair/run_modules/post_processing.py      | 10 ++++--
 3 files changed, 34 insertions(+), 18 deletions(-)
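
The reworked _plot writes one boxplot page per reference model into a single
pdf. A minimal, standalone sketch of that multi-page pattern (figure content
and file name are placeholders):

    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    refs = ["ols", "persi"]  # reference model names as in the patch defaults
    with PdfPages("skill_score_sectorial_NN.pdf") as pdf_pages:
        for ref in refs:
            fig, ax = plt.subplots()
            ax.set_title(f"skill score (NN vs. {ref})")
            # ... draw the per-sector boxplot for this reference here ...
            pdf_pages.savefig(fig)
            plt.close(fig)

Using PdfPages as a context manager closes the file even if a page fails to
render; the patch closes it explicitly instead, which is equivalent on the
happy path.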

diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index 02803225..11c71ae7 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -12,6 +12,7 @@ from typing import Union, Tuple, Dict, List
 import itertools
 import dask.array as da
 from collections import OrderedDict
+from mlair.helpers import to_list
 
 Data = Union[xr.DataArray, pd.DataFrame]
 
@@ -235,7 +236,9 @@ def skill_score_based_on_mse(data: xr.DataArray, obs_name: str, pred_name: str,
                              aggregation_dim: str = "index", competitor_dim: str = "type") -> xr.DataArray:
     obs = data.sel({competitor_dim: obs_name})
     pred = data.sel({competitor_dim: pred_name})
-    ref = data.sel({competitor_dim: ref_name})
+    # a single reference model name or a list of names is supported
+    ref = xr.concat([data.sel({competitor_dim: ref_n}) for ref_n in to_list(ref_name)],
+                    dim=competitor_dim).transpose(*data.dims)
     ss = 1 - mean_squared_error_nan(obs, pred, dim=aggregation_dim) / mean_squared_error_nan(obs, ref, dim=aggregation_dim)
     return ss
 
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 335e6739..fa121fe8 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -632,46 +632,55 @@ class PlotCompetitiveSkillScore(AbstractPlotClass):  # pragma: no cover
 class PlotSectorialSkillScore(AbstractPlotClass):  # pragma: no cover
 
     def __init__(self, data: xr.DataArray, plot_folder: str = ".", model_setup: str = "NN", sampling: str = "daily",
-                 model_name_for_plots: Union[str, None] = None, ahead_dim: str = "ahead", ):
+                 model_name_for_plots: Union[str, None] = None, ahead_dim: str = "ahead", reference_dim: str = "type"):
         """Initialise."""
         super().__init__(plot_folder, f"skill_score_sectorial_{model_setup}")
         self._model_setup = model_setup
         self._sampling = self._get_sampling(sampling)
         self._ahead_dim = ahead_dim
+        self._reference_dim = reference_dim
         self._labels = None
         self._model_name_for_plots = model_name_for_plots
         self._data, self._reference_model = self._prepare_data(data)
         logging.info("PlotSectorialSkillScore: finished _prepare_data(data)")
         self._plot()
         logging.info("PlotSectorialSkillScore: finished _plot()")
-        self._save()
         self.plot_name = self.plot_name + "_vertical"
         self._plot_vertical()
         logging.info("PlotSectorialSkillScore: finished _plot_vertical()")
         self._save()
 
     @TimeTrackingWrapper
     def _prepare_data(self, data: xr.DataArray):
         self._labels = [str(i) + self._sampling for i in data.coords[self._ahead_dim].values]
         reference_model = data.attrs["reference_model"]
         logging.info(f"PlotSectorialSkillScore._prepare_data: shape of data (xarray) is {data.shape}\n dims: {data.dims}")
-        data = data.to_dataframe("data")[["data"]].stack(level=0).reset_index(level=3, drop=True).reset_index(name="data")
+        # the reference-model dimension adds one index level, hence level=4 instead of 3
+        data = data.to_dataframe("data")[["data"]].stack(level=0).reset_index(level=4, drop=True).reset_index(name="data")
         return data, reference_model
 
     def _plot(self):
         size = max([len(np.unique(self._data.sector)), 6])
-        fig, ax = plt.subplots(figsize=(size, size * 0.8))
         data = self._data
-        sns.boxplot(x="sector", y="data", hue="ahead", data=data, whis=1, ax=ax, palette="Blues_r",
-                    showmeans=False, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
-                    )
-        ax.axhline(y=0, color="grey", linewidth=.5)
-        ax.set(ylabel=f"skill score ({self._model_setup} vs. {self._reference_model})", xlabel="sector",
-               title="summary of all stations", ylim=self._lim(data))
-        handles, _ = ax.get_legend_handles_labels()
-        plt.xticks(rotation=45, horizontalalignment="right")
-        ax.legend(handles, self._labels)
-        plt.tight_layout()
+        # write one page per reference model into a single pdf
+        plot_path = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf")
+        pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
+        for ref in self._reference_model:
+            ref_data = data[data[self._reference_dim] == ref]
+            fig, ax = plt.subplots(figsize=(size, size * 0.8))
+            sns.boxplot(x="sector", y="data", hue="ahead", data=ref_data, whis=1, ax=ax, palette="Blues_r",
+                        showmeans=False, flierprops={"marker": "."})
+            ax.axhline(y=0, color="grey", linewidth=.5)
+            ax.set(ylabel=f"skill score ({self._model_setup} vs. {ref})", xlabel="sector",
+                   title="summary of all stations", ylim=self._lim(ref_data))
+            handles, _ = ax.get_legend_handles_labels()
+            plt.xticks(rotation=45, horizontalalignment="right")
+            ax.legend(handles, self._labels)
+            plt.tight_layout()
+            pdf_pages.savefig()
+        # close the pdf and all remaining open figures
+        pdf_pages.close()
+        plt.close("all")
 
     def _plot_vertical(self):
         """Plot skill scores of the comparisons, but vertically aligned."""
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index d529bc71..452e2d6d 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -101,6 +101,7 @@ class PostProcessing(RunEnvironment):
         self.index_dim = "index"
         self.iter_dim = self.data_store.get("iter_dim")
         self.upstream_wind_sector = None
+        self.model_and_competitor_name_list = None
         self.model_display_name = self.data_store.get_default("model_display_name", default=self.model.model_name)
         self._run()
 
@@ -603,7 +604,8 @@ class PostProcessing(RunEnvironment):
             if "PlotSectorialSkillScore" in plot_list:
                 PlotSectorialSkillScore(self.skill_score_per_sector, plot_folder=self.plot_path,
                                         model_setup=self.model_display_name, sampling=self._sampling,
-                                        model_name_for_plots=self.model_name_for_plots, ahead_dim=self.ahead_dim
+                                        model_name_for_plots=self.model_name_for_plots, ahead_dim=self.ahead_dim,
+                                        reference_dim=self.model_type_dim
                                         )
         except Exception as e:
             logging.error(f"Could not create plot PlotSectorialSkillScore due to the following error: {e}"
@@ -992,7 +994,7 @@ class PostProcessing(RunEnvironment):
             external_data_expd.to_netcdf(os.path.join(path, f"forecasts_ds_{str(station)}_test.nc"))
     
     @TimeTrackingWrapper
-    def calculate_error_metrics_based_on_upstream_wind_dir(self, ref_name: str = "ols") -> xr.DataArray:
+    def calculate_error_metrics_based_on_upstream_wind_dir(self, ref_name: Union[str, List[str]] = ["ols", "persi"]) -> xr.DataArray:
         """
         Calculates the error metrics (mse)/skill scores based on the wind sector at time t0.
 
@@ -1023,7 +1025,8 @@ class PostProcessing(RunEnvironment):
                                                                                      )
             )
         sector_skill_scores = xr.concat(h_sector_skill_scores, dim="sector")
-        sector_skill_scores = sector_skill_scores.assign_attrs({f"reference_model": ref_name})
+        sector_skill_scores = sector_skill_scores.assign_attrs({f"reference_model": ref_name,
+                                                                "reference_model_dim": self.model_type_dim})
         return sector_skill_scores
 
     def calculate_error_metrics(self) -> Tuple[Dict, Dict, Dict, Dict]:
@@ -1066,6 +1069,7 @@ class PostProcessing(RunEnvironment):
             else:
                 model_list = None
 
+            self.model_and_competitor_name_list = model_list
             # test errors of competitors
             for model_type in (model_list or []):
                 if self.observation_indicator not in combined.coords[self.model_type_dim]:
-- 
GitLab