From 3e392698f003fee517a015109983d9dd6fefa0bb Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Tue, 8 Dec 2020 16:05:07 +0100 Subject: [PATCH] more plot options are available form outside for the PlotStationMap, during post processing the subsets train_val and test are indicated by different colors and markers --- mlair/plotting/postprocessing_plotting.py | 23 +++++++++++++++++++---- mlair/run_modules/post_processing.py | 4 +++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 35ab7032..dc332190 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -240,7 +240,7 @@ class PlotStationMap(AbstractPlotClass): :width: 400 """ - def __init__(self, generators: Dict, plot_folder: str = "."): + def __init__(self, generators: List, plot_folder: str = "."): """ Set attributes and create plot. @@ -279,13 +279,28 @@ class PlotStationMap(AbstractPlotClass): import cartopy.crs as ccrs if generators is not None: - for color, data_collection in generators.items(): + for element in generators: + data_collection, plot_opts = self._get_collection_and_opts(element) + marker = plot_opts.get("marker", "s") + ms = plot_opts.get("markersize", 6) + mec = plot_opts.get("mec", "k") + mfc = plot_opts.get("mfc", "b") for station in data_collection: coords = station.get_coordinates() IDx, IDy = coords["lon"], coords["lat"] - self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree()) + self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree()) - def _plot(self, generators: Dict): + @staticmethod + def _get_collection_and_opts(element): + if isinstance(element, tuple): + if len(element) == 1: + return element[0], {} + else: + return element + else: + return element, {} + + def _plot(self, generators: List): """ Create the station map plot. diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index d125474e..0b9393e0 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -277,7 +277,9 @@ class PostProcessing(RunEnvironment): logging.warning( f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") else: - PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path) + gens = [(self.train_val_data, {"mfc": "r", "marker": 8}), + (self.test_data, {"mfc": "b", "marker": 9})] + PlotStationMap(generators=gens, plot_folder=self.plot_path) if "PlotMonthlySummary" in plot_list: PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var, plot_folder=self.plot_path) -- GitLab