diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py index 39e20020f4f80a872428681d53e2ec9f1a3dd3f7..18466fd9ede1e666f52b49fa461585a7e38410dd 100644 --- a/mlair/data_handler/iterator.py +++ b/mlair/data_handler/iterator.py @@ -33,13 +33,18 @@ class StandardIterator(Iterator): class DataCollection(Iterable): - def __init__(self, collection: list = None): + def __init__(self, collection: list = None, name: str = None): if collection is None: collection = [] assert isinstance(collection, list) self._collection = collection self._mapping = {} self._set_mapping() + self._name = name + + @property + def name(self): + return self._name def __len__(self): return len(self._collection) diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 35ab70327d88b161ce23228f2c42ea5d906a3a30..aee724397a6c6e2c83d2990238035bcaec57d570 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -10,6 +10,7 @@ from typing import Dict, List, Tuple import matplotlib import matplotlib.patches as mpatches +import matplotlib.lines as mlines import matplotlib.pyplot as plt import matplotlib.dates as mdates import numpy as np @@ -119,10 +120,11 @@ class AbstractPlotClass: """ Standard colors used for train-, val-, and test-sets during postprocessing """ - colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"} # hex code + colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"} # hex code return colors + @TimeTrackingWrapper class PlotMonthlySummary(AbstractPlotClass): """ @@ -240,7 +242,7 @@ class PlotStationMap(AbstractPlotClass): :width: 400 """ - def __init__(self, generators: Dict, plot_folder: str = "."): + def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"): """ Set attributes and create plot. @@ -248,11 +250,11 @@ class PlotStationMap(AbstractPlotClass): as value. :param plot_folder: path to save the plot (default: current directory) """ - super().__init__(plot_folder, "station_map") + super().__init__(plot_folder, plot_name) self._ax = None self._gl = None self._plot(generators) - self._save() + self._save(bbox_inches="tight") def _draw_background(self): """Draw coastline, lakes, ocean, rivers and country borders as background on the map.""" @@ -279,13 +281,44 @@ class PlotStationMap(AbstractPlotClass): import cartopy.crs as ccrs if generators is not None: - for color, data_collection in generators.items(): + legend_elements = [] + default_colors = self.get_dataset_colors() + for element in generators: + data_collection, plot_opts = self._get_collection_and_opts(element) + name = data_collection.name or "unknown" + marker = plot_opts.get("marker", "s") + ms = plot_opts.get("ms", 6) + mec = plot_opts.get("mec", "k") + mfc = plot_opts.get("mfc", default_colors.get(name, "b")) + legend_elements.append( + mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None', + label=name)) for station in data_collection: coords = station.get_coordinates() IDx, IDy = coords["lon"], coords["lat"] - self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree()) + self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree()) + if len(legend_elements) > 0: + self._ax.legend(handles=legend_elements, loc='best') + + @staticmethod + def _adjust_marker(marker): + _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"} + if isinstance(marker, int) and marker in _adjust.keys(): + return _adjust[marker] + else: + return marker + + @staticmethod + def _get_collection_and_opts(element): + if isinstance(element, tuple): + if len(element) == 1: + return element[0], {} + else: + return element + else: + return element, {} - def _plot(self, generators: Dict): + def _plot(self, generators: List): """ Create the station map plot. diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index d125474e2224d4311137702c2796bd89b9f198ee..cb24ca3cf14f1c04a99af65be15edc7151478878 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -277,7 +277,13 @@ class PostProcessing(RunEnvironment): logging.warning( f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") else: - PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path) + gens = [(self.train_data, {"marker": 5, "ms": 9}), + (self.val_data, {"marker": 6, "ms": 9}), + (self.test_data, {"marker": 4, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=self.plot_path) + gens = [(self.train_val_data, {"marker": 8, "ms": 9}), + (self.test_data, {"marker": 9, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var") if "PlotMonthlySummary" in plot_list: PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var, plot_folder=self.plot_path) diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index c9e92e7deb22dcff12e9d4ab982f14289f764a97..bd3b9ec6fd471dcb6e794a3ba9b498e18ad76a37 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -257,7 +257,7 @@ class PreProcessing(RunEnvironment): logging.info("setup transformation using train data exclusively") self.transformation(data_handler, set_stations) # start station check - collection = DataCollection() + collection = DataCollection(name=set_name) valid_stations = [] kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name)