diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index ac20c6f8084a04c8c0f99f9bd115888fc0d6d661..c46911c3200c2a34622c8d9b5445597473869066 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -124,6 +124,7 @@ class AbstractPlotClass: return colors + @TimeTrackingWrapper class PlotMonthlySummary(AbstractPlotClass): """ @@ -281,15 +282,17 @@ class PlotStationMap(AbstractPlotClass): import cartopy.crs as ccrs if generators is not None: legend_elements = [] + default_colors = self.get_dataset_colors() for element in generators: data_collection, plot_opts = self._get_collection_and_opts(element) + name = data_collection.name or "unknown" marker = plot_opts.get("marker", "s") ms = plot_opts.get("ms", 6) mec = plot_opts.get("mec", "k") - mfc = plot_opts.get("mfc", "b") - name = data_collection.name or "unknown" + mfc = plot_opts.get("mfc", default_colors.get(name, "b")) legend_elements.append( - mlines.Line2D([], [], mfc=mfc, mec=mec, marker=marker, ms=ms, linestyle='None', label=name)) + mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None', + label=name)) for station in data_collection: coords = station.get_coordinates() IDx, IDy = coords["lon"], coords["lat"] @@ -297,6 +300,14 @@ class PlotStationMap(AbstractPlotClass): if len(legend_elements) > 0: self._ax.legend(handles=legend_elements, loc='best') + @staticmethod + def _adjust_marker(marker): + _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"} + if isinstance(marker, int) and marker in _adjust.keys(): + return _adjust[marker] + else: + return marker + @staticmethod def _get_collection_and_opts(element): if isinstance(element, tuple): diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 300d2e1dbd2ded279ee289ee64a2acd5f5c36fbc..79c0e3bf8eb8eb9ef7550984bdf6d92715a9b527 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -277,8 +277,9 @@ class PostProcessing(RunEnvironment): logging.warning( f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") else: - gens = [(self.train_val_data, {"mfc": "r", "marker": 8, "ms": 10}), - (self.test_data, {"mfc": "b", "marker": 9, "ms": 10})] + gens = [(self.train_data, {"marker": 5, "ms": 9}), + (self.val_data, {"marker": 6, "ms": 9}), + (self.test_data, {"marker": 4, "ms": 9})] PlotStationMap(generators=gens, plot_folder=self.plot_path) if "PlotMonthlySummary" in plot_list: PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var,