From 42a969dac0e5048e3f1fca05cc6576380ea6ffe3 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Fri, 30 Sep 2022 16:14:47 +0200
Subject: [PATCH] PlotErrorsOnMap features now ahead split

---
 mlair/plotting/abstract_plot_class.py     |  9 ++--
 mlair/plotting/data_insight_plotting.py   |  4 +-
 mlair/plotting/postprocessing_plotting.py | 55 +++++++++++------------
 mlair/run_modules/post_processing.py      |  2 +-
 4 files changed, 33 insertions(+), 37 deletions(-)

diff --git a/mlair/plotting/abstract_plot_class.py b/mlair/plotting/abstract_plot_class.py
index a26023bb..377dc6b7 100644
--- a/mlair/plotting/abstract_plot_class.py
+++ b/mlair/plotting/abstract_plot_class.py
@@ -92,11 +92,10 @@ class AbstractPlotClass:  # pragma: no cover
         plt.rcParams.update(self.rc_params)
 
     @staticmethod
-    def _get_sampling(sampling):
-        if sampling == "daily":
-            return "D"
-        elif sampling == "hourly":
-            return "h"
+    def _get_sampling(sampling, pos=1):
+        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
+        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
+        return sampling, sampling_letter
 
     @staticmethod
     def get_dataset_colors():
diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
index db2b3340..d33f3abb 100644
--- a/mlair/plotting/data_insight_plotting.py
+++ b/mlair/plotting/data_insight_plotting.py
@@ -188,7 +188,7 @@ class PlotAvailability(AbstractPlotClass):  # pragma: no cover
         super().__init__(plot_folder, "data_availability")
         self.time_dim = time_dimension
         self.window_dim = window_dimension
-        self.sampling = self._get_sampling(sampling)
+        self.sampling = self._get_sampling(sampling)[1]
         self.linewidth = None
         if self.sampling == 'h':
             self.linewidth = 0.001
@@ -321,7 +321,7 @@ class PlotAvailabilityHistogram(AbstractPlotClass):  # pragma: no cover
     def _set_dims_from_datahandler(self, data_handler):
         self.temporal_dim = data_handler.id_class.time_dim
         self.target_dim = data_handler.id_class.target_dim
-        self.freq = self._get_sampling(data_handler.id_class.sampling)
+        self.freq = self._get_sampling(data_handler.id_class.sampling)[1]
 
     @property
     def allowed_plot_types(self):
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 62ab3de3..d1a68896 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -748,19 +748,13 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass):  # pragma: no cover
                 data = data.assign_coords({self._x_name: new_boot_coords})
             except NotImplementedError:
                 pass
-        _, sampling_letter = self._get_target_sampling(sampling, 1)
+        _, sampling_letter = self._get_sampling(sampling, 1)
         self._labels = [str(i) + sampling_letter for i in data.coords[self._ahead_dim].values]
         if station_dim not in data.dims:
             data = data.expand_dims(station_dim)
         self._number_of_bootstraps = np.unique(data.coords[self._boot_dim].values).shape[0]
         return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()).dropna()
 
-    @staticmethod
-    def _get_target_sampling(sampling, pos):
-        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
-        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
-        return sampling, sampling_letter
-
     def _return_vars_without_number_tag(self, values, split_by, keep, as_unique=False):
         arr = np.array([v.split(split_by) for v in values])
         num = arr[:, 0]
@@ -1462,12 +1456,6 @@ class PlotSeasonalMSEStack(AbstractPlotClass):
         season_share = xr_data.sel({season_dim: "total"}) * factor
         return season_share.sortby(season_share.sum([self.season_dim, self.ahead_dim])).transpose(*self.dim_order)
 
-    @staticmethod
-    def _get_target_sampling(sampling, pos):
-        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
-        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
-        return sampling, sampling_letter
-
     @staticmethod
     def _set_bar_label(ax):
         opts = {}
@@ -1481,7 +1469,7 @@ class PlotSeasonalMSEStack(AbstractPlotClass):
             ax.bar_label(c, labels=_l, label_type='center')
 
     def _plot(self, dim, split_ahead=True, sampling="daily", orientation="vertical"):
-        _, sampling_letter = self._get_target_sampling(sampling, 1)
+        _, sampling_letter = self._get_sampling(sampling, 1)
         if split_ahead is False:
             self.plot_name = self.plot_name_orig + "_total_" + orientation
             data = self._data.mean(dim)
@@ -1534,26 +1522,34 @@ class PlotErrorsOnMap(AbstractPlotClass):
     from mlair.plotting.data_insight_plotting import PlotStationMap
 
     def __init__(self, data_gen, errors, error_metric, plot_folder: str = ".", iter_dim: str = "station",
-                 model_type_dim: str = "type", ahead_dim: str = "ahead"):
+                 model_type_dim: str = "type", ahead_dim: str = "ahead", sampling: str = "daily"):
 
         super().__init__(plot_folder, f"map_plot_{error_metric}")
-
         plot_path = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
         pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
         error_metric_units = helpers.statistics.get_error_metrics_units("ppb")[error_metric]
         error_metric_name = helpers.statistics.get_error_metrics_long_name()[error_metric]
+        self.sampling = self._get_sampling(sampling, 1)[1]
 
         coords = self._extract_coords(data_gen)
-        error_data = {}
-        for model_type in errors.coords[model_type_dim].values:
-            error_data[model_type] = self._prepare_data(errors, model_type_dim, model_type, ahead_dim, error_metric)
-
-        limits = self._calculate_limits(error_data)
-
-        for model_type, error in error_data.items():
-            plot_data = pd.concat([coords, error], axis=1)
-            self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type, limits)
-            pdf_pages.savefig()
+        for split_ahead in [False, True]:
+            error_data = {}
+            for model_type in errors.coords[model_type_dim].values:
+                error_data[model_type] = self._prepare_data(errors, model_type_dim, model_type, ahead_dim, error_metric,
+                                                            split_ahead=split_ahead)
+            limits = self._calculate_limits(error_data)
+            for model_type, error in error_data.items():
+                if split_ahead is True:
+                    for ahead in error.index.unique(1).to_list():
+                        error_ahead = error.query(f"{ahead_dim} == {ahead}").droplevel(1)
+                        plot_data = pd.concat([coords, error_ahead], axis=1)
+                        self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type, limits,
+                                  ahead=ahead)
+                        pdf_pages.savefig()
+                else:
+                    plot_data = pd.concat([coords, error], axis=1)
+                    self.plot(plot_data, error_metric, error_metric_name, error_metric_units, model_type, limits)
+                    pdf_pages.savefig()
         pdf_pages.close()
         plt.close('all')
 
@@ -1571,7 +1567,7 @@ class PlotErrorsOnMap(AbstractPlotClass):
         vmin = relative_round(bound_lims[0], 2, floor=True)
         vmax = relative_round(bound_lims[1], 2, ceil=True)
         interval = relative_round((vmax - vmin) / ncolors, 1, ceil=True)
-        bounds = np.arange(vmin, vmax, interval)
+        bounds = np.sort(np.arange(vmax, vmin, -interval))
         return bounds
 
     @staticmethod
@@ -1588,7 +1584,7 @@ class PlotErrorsOnMap(AbstractPlotClass):
             cmap = sns.color_palette("magma_r", as_cmap=True)
         return cmap
 
-    def plot(self, plot_data, error_metric, error_long_name, error_units, model_type, limits):
+    def plot(self, plot_data, error_metric, error_long_name, error_units, model_type, limits, ahead=None):
         import cartopy.crs as ccrs
         from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
         fig = plt.figure(figsize=(10, 5))
@@ -1606,7 +1602,8 @@ class PlotErrorsOnMap(AbstractPlotClass):
         cbar_label = f"{error_long_name} (in {error_units})" if error_units is not None else error_long_name
         plt.colorbar(cb, label=cbar_label)
         self._adjust_extent(ax)
-        plt.title(model_type)
+        title = model_type if ahead is None else f"{model_type} ({ahead}{self.sampling})"
+        plt.title(title)
         plt.tight_layout()
 
     @staticmethod
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 0ff09e92..0fb14f55 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -735,7 +735,7 @@ class PostProcessing(RunEnvironment):
                 for error_metric in self.errors.keys():
                     try:
                         PlotErrorsOnMap(self.test_data, self.errors[error_metric], error_metric,
-                                        plot_folder=self.plot_path)
+                                        plot_folder=self.plot_path, sampling=self._sampling)
                     except Exception as e:
                         logging.error(f"Could not create plot PlotErrorsOnMap for {error_metric} due to the following "
                                       f"error: {e}\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
-- 
GitLab