From 3e392698f003fee517a015109983d9dd6fefa0bb Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 8 Dec 2020 16:05:07 +0100
Subject: [PATCH] more plot options are available form outside for the
 PlotStationMap, during post processing the subsets train_val and test are
 indicated by different colors and markers

---
 mlair/plotting/postprocessing_plotting.py | 23 +++++++++++++++++++----
 mlair/run_modules/post_processing.py      |  4 +++-
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 35ab7032..dc332190 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -240,7 +240,7 @@ class PlotStationMap(AbstractPlotClass):
         :width: 400
     """
 
-    def __init__(self, generators: Dict, plot_folder: str = "."):
+    def __init__(self, generators: List, plot_folder: str = "."):
         """
         Set attributes and create plot.
 
@@ -279,13 +279,28 @@ class PlotStationMap(AbstractPlotClass):
 
         import cartopy.crs as ccrs
         if generators is not None:
-            for color, data_collection in generators.items():
+            for element in generators:
+                data_collection, plot_opts = self._get_collection_and_opts(element)
+                marker = plot_opts.get("marker", "s")
+                ms = plot_opts.get("markersize", 6)
+                mec = plot_opts.get("mec", "k")
+                mfc = plot_opts.get("mfc", "b")
                 for station in data_collection:
                     coords = station.get_coordinates()
                     IDx, IDy = coords["lon"], coords["lat"]
-                    self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree())
+                    self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree())
 
-    def _plot(self, generators: Dict):
+    @staticmethod
+    def _get_collection_and_opts(element):
+        if isinstance(element, tuple):
+            if len(element) == 1:
+                return element[0], {}
+            else:
+                return element
+        else:
+            return element, {}
+
+    def _plot(self, generators: List):
         """
         Create the station map plot.
 
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index d125474e..0b9393e0 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -277,7 +277,9 @@ class PostProcessing(RunEnvironment):
                 logging.warning(
                     f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}")
             else:
-                PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
+                gens = [(self.train_val_data, {"mfc": "r", "marker": 8}),
+                        (self.test_data, {"mfc": "b", "marker": 9})]
+                PlotStationMap(generators=gens, plot_folder=self.plot_path)
         if "PlotMonthlySummary" in plot_list:
             PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var,
                                plot_folder=self.plot_path)
-- 
GitLab