diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py
index 39e20020f4f80a872428681d53e2ec9f1a3dd3f7..18466fd9ede1e666f52b49fa461585a7e38410dd 100644
--- a/mlair/data_handler/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -33,13 +33,18 @@ class StandardIterator(Iterator):
 
 class DataCollection(Iterable):
 
-    def __init__(self, collection: list = None):
+    def __init__(self, collection: list = None, name: str = None):
         if collection is None:
             collection = []
         assert isinstance(collection, list)
         self._collection = collection
         self._mapping = {}
         self._set_mapping()
+        self._name = name
+
+    @property
+    def name(self):
+        return self._name
 
     def __len__(self):
         return len(self._collection)
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 35ab70327d88b161ce23228f2c42ea5d906a3a30..aee724397a6c6e2c83d2990238035bcaec57d570 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -10,6 +10,7 @@ from typing import Dict, List, Tuple
 
 import matplotlib
 import matplotlib.patches as mpatches
+import matplotlib.lines as mlines
 import matplotlib.pyplot as plt
 import matplotlib.dates as mdates
 import numpy as np
@@ -119,10 +120,11 @@ class AbstractPlotClass:
         """
         Standard colors used for train-, val-, and test-sets during postprocessing
         """
-        colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"}  # hex code
+        colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"}  # hex code
         return colors
 
 
+
 @TimeTrackingWrapper
 class PlotMonthlySummary(AbstractPlotClass):
     """
@@ -240,7 +242,7 @@ class PlotStationMap(AbstractPlotClass):
         :width: 400
     """
 
-    def __init__(self, generators: Dict, plot_folder: str = "."):
+    def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
         """
         Set attributes and create plot.
 
@@ -248,11 +250,11 @@ class PlotStationMap(AbstractPlotClass):
         as value.
         :param plot_folder: path to save the plot (default: current directory)
         """
-        super().__init__(plot_folder, "station_map")
+        super().__init__(plot_folder, plot_name)
         self._ax = None
         self._gl = None
         self._plot(generators)
-        self._save()
+        self._save(bbox_inches="tight")
 
     def _draw_background(self):
         """Draw coastline, lakes, ocean, rivers and country borders as background on the map."""
@@ -279,13 +281,44 @@ class PlotStationMap(AbstractPlotClass):
 
         import cartopy.crs as ccrs
         if generators is not None:
-            for color, data_collection in generators.items():
+            legend_elements = []
+            default_colors = self.get_dataset_colors()
+            for element in generators:
+                data_collection, plot_opts = self._get_collection_and_opts(element)
+                name = data_collection.name or "unknown"
+                marker = plot_opts.get("marker", "s")
+                ms = plot_opts.get("ms", 6)
+                mec = plot_opts.get("mec", "k")
+                mfc = plot_opts.get("mfc", default_colors.get(name, "b"))
+                legend_elements.append(
+                    mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None',
+                                  label=name))
                 for station in data_collection:
                     coords = station.get_coordinates()
                     IDx, IDy = coords["lon"], coords["lat"]
-                    self._ax.plot(IDx, IDy, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree())
+                    self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree())
+            if len(legend_elements) > 0:
+                self._ax.legend(handles=legend_elements, loc='best')
+
+    @staticmethod
+    def _adjust_marker(marker):
+        _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
+        if isinstance(marker, int) and marker in _adjust.keys():
+            return _adjust[marker]
+        else:
+            return marker
+
+    @staticmethod
+    def _get_collection_and_opts(element):
+        if isinstance(element, tuple):
+            if len(element) == 1:
+                return element[0], {}
+            else:
+                return element
+        else:
+            return element, {}
 
-    def _plot(self, generators: Dict):
+    def _plot(self, generators: List):
         """
         Create the station map plot.
 
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index d125474e2224d4311137702c2796bd89b9f198ee..cb24ca3cf14f1c04a99af65be15edc7151478878 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -277,7 +277,13 @@ class PostProcessing(RunEnvironment):
                 logging.warning(
                     f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}")
             else:
-                PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path)
+                gens = [(self.train_data, {"marker": 5, "ms": 9}),
+                        (self.val_data, {"marker": 6, "ms": 9}),
+                        (self.test_data, {"marker": 4, "ms": 9})]
+                PlotStationMap(generators=gens, plot_folder=self.plot_path)
+                gens = [(self.train_val_data, {"marker": 8, "ms": 9}),
+                        (self.test_data, {"marker": 9, "ms": 9})]
+                PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var")
         if "PlotMonthlySummary" in plot_list:
             PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var,
                                plot_folder=self.plot_path)
diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py
index c9e92e7deb22dcff12e9d4ab982f14289f764a97..bd3b9ec6fd471dcb6e794a3ba9b498e18ad76a37 100644
--- a/mlair/run_modules/pre_processing.py
+++ b/mlair/run_modules/pre_processing.py
@@ -257,7 +257,7 @@ class PreProcessing(RunEnvironment):
             logging.info("setup transformation using train data exclusively")
             self.transformation(data_handler, set_stations)
         # start station check
-        collection = DataCollection()
+        collection = DataCollection(name=set_name)
         valid_stations = []
         kwargs = self.data_store.create_args_dict(data_handler.requirements(), scope=set_name)