From d68daa628db70cf0ff0cbe4e786832fc1c93e5a2 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Tue, 8 Dec 2020 17:15:03 +0100 Subject: [PATCH] changed colors for station map to default colors --- mlair/plotting/postprocessing_plotting.py | 17 ++++++++++++++--- mlair/run_modules/post_processing.py | 5 +++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index ac20c6f8..c46911c3 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -124,6 +124,7 @@ class AbstractPlotClass: return colors + @TimeTrackingWrapper class PlotMonthlySummary(AbstractPlotClass): """ @@ -281,15 +282,17 @@ class PlotStationMap(AbstractPlotClass): import cartopy.crs as ccrs if generators is not None: legend_elements = [] + default_colors = self.get_dataset_colors() for element in generators: data_collection, plot_opts = self._get_collection_and_opts(element) + name = data_collection.name or "unknown" marker = plot_opts.get("marker", "s") ms = plot_opts.get("ms", 6) mec = plot_opts.get("mec", "k") - mfc = plot_opts.get("mfc", "b") - name = data_collection.name or "unknown" + mfc = plot_opts.get("mfc", default_colors.get(name, "b")) legend_elements.append( - mlines.Line2D([], [], mfc=mfc, mec=mec, marker=marker, ms=ms, linestyle='None', label=name)) + mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None', + label=name)) for station in data_collection: coords = station.get_coordinates() IDx, IDy = coords["lon"], coords["lat"] @@ -297,6 +300,14 @@ class PlotStationMap(AbstractPlotClass): if len(legend_elements) > 0: self._ax.legend(handles=legend_elements, loc='best') + @staticmethod + def _adjust_marker(marker): + _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"} + if isinstance(marker, int) and marker in _adjust.keys(): + return _adjust[marker] + else: + return marker + @staticmethod def _get_collection_and_opts(element): if isinstance(element, tuple): diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 300d2e1d..79c0e3bf 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -277,8 +277,9 @@ class PostProcessing(RunEnvironment): logging.warning( f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") else: - gens = [(self.train_val_data, {"mfc": "r", "marker": 8, "ms": 10}), - (self.test_data, {"mfc": "b", "marker": 9, "ms": 10})] + gens = [(self.train_data, {"marker": 5, "ms": 9}), + (self.val_data, {"marker": 6, "ms": 9}), + (self.test_data, {"marker": 4, "ms": 9})] PlotStationMap(generators=gens, plot_folder=self.plot_path) if "PlotMonthlySummary" in plot_list: PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var, -- GitLab