From b2003ea31db70cc2f74a39d7edea8bd8503fd05c Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 30 Aug 2022 14:24:44 +0200
Subject: [PATCH] new plot avail PlotSeasonalMSEStack, not in default plot
 list. \close #422

---
 mlair/plotting/postprocessing_plotting.py | 87 +++++++++++++++++++++++
 mlair/run_modules/post_processing.py      | 12 +++-
 2 files changed, 98 insertions(+), 1 deletion(-)

diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index c7647ef5..d25736f1 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -1319,6 +1319,93 @@ class PlotTimeEvolutionMetric(AbstractPlotClass):
         self._save()
 
 
+@TimeTrackingWrapper
+class PlotSeasonalMSEStack(AbstractPlotClass):
+
+    def __init__(self, data_path: str, plot_folder: str = ".", boot_dim="boots", ahead_dim="ahead",
+                 sampling: str = "daily", error_measure: str = "MSE", error_unit: str = "ppb$^2$"):
+        """Set attributes and create plot."""
+        super().__init__(plot_folder, "seasonal_mse_stack_plot")
+        self.plot_name_orig = "seasonal_mse_stack_plot"
+        self._data_path = data_path
+        self.season_dim = "season"
+        self.error_unit = error_unit
+        self.error_measure = error_measure
+        self._data = self._prepare_data(boot_dim, data_path)
+        for orientation in ["horizontal", "vertical"]:
+            for split_ahead in [True, False]:
+                self._plot(ahead_dim, split_ahead, sampling, orientation)
+                self._save(bbox_inches="tight")
+
+    def _prepare_data(self, boot_dim, data_path):
+        season_dim = self.season_dim
+        data = {}
+        for season in ["total", "DJF", "MAM", "JJA", "SON"]:
+            if season == "total":
+                file_name = "uncertainty_estimate_raw_results.nc"
+            else:
+                file_name = f"uncertainty_estimate_raw_results_{season}.nc"
+            with xr.open_dataarray(os.path.join(data_path, file_name)) as d:
+                data[season] = d
+        mean = {}
+        for season in data.keys():
+            mean[season] = data[season].mean(boot_dim)
+        xr_data = xr.Dataset(mean).to_array(season_dim)
+        factor = xr_data.sel({season_dim: "total"}) / xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}).sum(
+            season_dim)
+        return xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}) * factor
+
+    @staticmethod
+    def _get_target_sampling(sampling, pos):
+        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
+        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
+        return sampling, sampling_letter
+
+    def _plot(self, dim, split_ahead=True, sampling="daily", orientation="vertical"):
+        _, sampling_letter = self._get_target_sampling(sampling, 1)
+        if split_ahead is True:
+            self.plot_name = self.plot_name_orig + "_total_" + orientation
+            data = self._data.mean(dim)
+            if orientation == "vertical":
+                fig, ax = plt.subplots(1, 1)
+                data.to_pandas().T.plot.bar(ax=ax, stacked=True, cmap="Dark2", legend=False)
+                ax.xaxis.label.set_visible(False)
+                ax.set_ylabel(f"{self.error_measure} (in {self.error_unit})")
+            else:
+                m = data.to_pandas().T.shape[0]
+                fig, ax = plt.subplots(1, 1, figsize=(6, m))
+                data.to_pandas().T.plot.barh(ax=ax, stacked=True, cmap="Dark2", legend=False)
+                ax.yaxis.label.set_visible(False)
+                ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})")
+            fig.legend(*ax.get_legend_handles_labels(), loc="upper center", ncol=4)
+            fig.tight_layout(rect=[0, 0, 1, 0.9])
+        else:
+            self.plot_name = self.plot_name_orig + "_" + orientation
+            data = self._data
+            n = len(data.coords[dim])
+            if orientation == "vertical":
+                fig, ax = plt.subplots(1, n, sharey=True)
+                for i, sel in enumerate(data.coords[dim].values):
+                    data.sel({dim: sel}).to_pandas().T.plot.bar(ax=ax[i], stacked=True, cmap="Dark2", legend=False)
+                    label = str(sel) + sampling_letter
+                    ax[i].set_title(label)
+                    ax[i].xaxis.label.set_visible(False)
+                ax[0].set_ylabel(f"{self.error_measure} (in {self.error_unit})")
+                fig.legend(*ax[0].get_legend_handles_labels(), loc="upper center", ncol=4)
+                fig.tight_layout(rect=[0, 0, 1, 0.9])
+            else:
+                m = data.max(self.season_dim).shape
+                fig, ax = plt.subplots(n, 1, sharex=True, figsize=(6, np.prod(m) * 0.6))
+                for i, sel in enumerate(data.coords[dim].values):
+                    data.sel({dim: sel}).to_pandas().T.plot.barh(ax=ax[i], stacked=True, cmap="Dark2", legend=False)
+                    label = str(sel) + sampling_letter
+                    ax[i].set_title(label)
+                    ax[i].yaxis.label.set_visible(False)
+                ax[-1].set_xlabel(f"{self.error_measure} (in {self.error_unit})")
+                fig.legend(*ax[0].get_legend_handles_labels(), loc="upper center", ncol=4)
+                fig.tight_layout(rect=[0, 0, 1, 0.95])
+
+
 if __name__ == "__main__":
     stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
     path = "../../testrun_network/forecasts"
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index a48a82b2..d65a2001 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -24,7 +24,7 @@ from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules import AbstractModelClass
 from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \
     PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \
-    PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric
+    PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric, PlotSeasonalMSEStack
 from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \
     PlotPeriodogram, PlotDataHistogram
 from mlair.run_modules.run_environment import RunEnvironment
@@ -681,6 +681,16 @@ class PostProcessing(RunEnvironment):
             logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}"
                           f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
 
+        try:
+            if "PlotSeasonalMSEStack" in plot_list:
+                report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+                PlotSeasonalMSEStack(data_path=report_path, plot_folder=self.plot_path,
+                                     boot_dim=self.uncertainty_estimate_boot_dim, ahead_dim=self.ahead_dim,
+                                     sampling=self._sampling, error_measure="Mean Squared Error", error_unit=r"ppb$^2$")
+        except Exception as e:
+            logging.error(f"Could not create plot PlotSeasonalMSEStack due to the following error: {e}"
+                          f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
+
     @TimeTrackingWrapper
     def calculate_test_score(self):
         """Evaluate test score of model and save locally."""
-- 
GitLab