From 317f24468d546fc91a9801bc4f07496d6d31d0aa Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 13 Apr 2021 11:49:37 +0200
Subject: [PATCH] split plots into abstract, pre, and postprocessing plots

---
 mlair/plotting/abstract_plot_class.py     | 101 +++++
 mlair/plotting/postprocessing_plotting.py | 523 +---------------------
 mlair/plotting/preprocessing_plotting.py  | 438 ++++++++++++++++++
 mlair/run_modules/post_processing.py      |   6 +-
 4 files changed, 543 insertions(+), 525 deletions(-)
 create mode 100644 mlair/plotting/abstract_plot_class.py
 create mode 100644 mlair/plotting/preprocessing_plotting.py

diff --git a/mlair/plotting/abstract_plot_class.py b/mlair/plotting/abstract_plot_class.py
new file mode 100644
index 00000000..dab45156
--- /dev/null
+++ b/mlair/plotting/abstract_plot_class.py
@@ -0,0 +1,101 @@
+"""Abstract plot class that should be used for preprocessing and postprocessing plots."""
+__author__ = "Lukas Leufen"
+__date__ = '2021-04-13'
+
+import logging
+import os
+
+from matplotlib import pyplot as plt
+
+
+class AbstractPlotClass:
+    """
+    Abstract base class for all plotting routines to unify the plot workflow.
+
+    Each subclass must implement a _plot method. Create a plot class like:
+
+    .. code-block:: python
+
+        class MyCustomPlot(AbstractPlotClass):
+
+            def __init__(self, plot_folder, *args, **kwargs):
+                super().__init__(plot_folder, "custom_plot_name")
+                self._data = self._prepare_data(*args, **kwargs)
+                self._plot(*args, **kwargs)
+                self._save()
+
+            def _prepare_data(self, *args, **kwargs):
+                <your custom data preparation>
+                return data
+
+            def _plot(self, *args, **kwargs):
+                <your custom plotting without saving>
+
+    The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are
+    using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi by default (can be
+    set in super class initialisation).
+
+    Methods like the shown _prepare_data() are optional. The only method that must be implemented is _plot.
+
+    If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot
+    class. It will log the spent time if you call your plotting without saving the returned object.
+
+    .. code-block:: python
+
+        @TimeTrackingWrapper
+        class MyCustomPlot(AbstractPlotClass):
+            pass
+
+    Let's assume it takes a while to create this very special plot.
+
+    >>> MyCustomPlot()
+    INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss)
+
+    """
+
+    def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None):
+        """Set up plot folder (created if missing), plot name, resolution (default 500dpi) and matplotlib rc params."""
+        plot_folder = os.path.abspath(plot_folder)
+        if not os.path.exists(plot_folder):  # create the target directory (including parents) on first use
+            os.makedirs(plot_folder)
+        self.plot_folder = plot_folder  # stored as absolute path
+        self.plot_name = plot_name  # file name without extension, ".pdf" is appended in _save
+        self.resolution = resolution  # handed to plt.savefig as dpi in _save
+        if rc_params is None:  # fall back to 'large' font sizes for labels, ticks, legend and title
+            rc_params = {'axes.labelsize': 'large',
+                         'xtick.labelsize': 'large',
+                         'ytick.labelsize': 'large',
+                         'legend.fontsize': 'large',
+                         'axes.titlesize': 'large',
+                         }
+        self.rc_params = rc_params
+        self._update_rc_params()  # writes rc_params into plt.rcParams
+
+    def _plot(self, *args):
+        """Abstract plot method that must be implemented by each subclass (raises NotImplementedError here)."""
+        raise NotImplementedError
+
+    def _save(self, **kwargs):
+        """Store plot as <plot_folder>/<plot_name>.pdf (kwargs go to plt.savefig) and close all open figures."""
+        plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
+        logging.debug(f"... save plot to {plot_name}")
+        plt.savefig(plot_name, dpi=self.resolution, **kwargs)
+        plt.close('all')
+
+    def _update_rc_params(self):  # push the stored rc_params into matplotlib's plt.rcParams
+        plt.rcParams.update(self.rc_params)
+
+    @staticmethod
+    def _get_sampling(sampling):  # "daily" -> "D", "hourly" -> "h" (presumably pandas freq aliases), else None
+        if sampling == "daily":
+            return "D"
+        elif sampling == "hourly":
+            return "h"
+
+    @staticmethod
+    def get_dataset_colors():
+        """
+        Return the standard hex colors of the "train", "val", "test" and "train_val" subsets as a dictionary.
+        """
+        colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"}  # hex code
+        return colors
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index d769fabc..491aa52e 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -9,10 +9,7 @@ import warnings
 from typing import Dict, List, Tuple
 
 import matplotlib
-import matplotlib.patches as mpatches
-import matplotlib.lines as mlines
 import matplotlib.pyplot as plt
-import matplotlib.dates as mdates
 import numpy as np
 import pandas as pd
 import seaborn as sns
@@ -22,6 +19,7 @@ from matplotlib.backends.backend_pdf import PdfPages
 from mlair import helpers
 from mlair.data_handler.iterator import DataCollection
 from mlair.helpers import TimeTrackingWrapper
+from mlair.plotting.abstract_plot_class import AbstractPlotClass
 
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
@@ -31,100 +29,6 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
 # import matplotlib.pyplot as plt
 
 
-class AbstractPlotClass:
-    """
-    Abstract class for all plotting routines to unify plot workflow.
-
-    Each inheritance requires a _plot method. Create a plot class like:
-
-    .. code-block:: python
-
-        class MyCustomPlot(AbstractPlotClass):
-
-            def __init__(self, plot_folder, *args, **kwargs):
-                super().__init__(plot_folder, "custom_plot_name")
-                self._data = self._prepare_data(*args, **kwargs)
-                self._plot(*args, **kwargs)
-                self._save()
-
-            def _prepare_data(*args, **kwargs):
-                <your custom data preparation>
-                return data
-
-            def _plot(*args, **kwargs):
-                <your custom plotting without saving>
-
-    The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are
-    using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per default (can be
-    set in super class initialisation).
-
-    Methods like the shown _prepare_data() are optional. The only method required to implement is _plot.
-
-    If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot
-    class. It will log the spent time if you call your plotting without saving the returned object.
-
-    .. code-block:: python
-
-        @TimeTrackingWrapper
-        class MyCustomPlot(AbstractPlotClass):
-            pass
-
-    Let's assume it takes a while to create this very special plot.
-
-    >>> MyCustomPlot()
-    INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss)
-
-    """
-
-    def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None):
-        """Set up plot folder and name, and plot resolution (default 500dpi)."""
-        plot_folder = os.path.abspath(plot_folder)
-        if not os.path.exists(plot_folder):
-            os.makedirs(plot_folder)
-        self.plot_folder = plot_folder
-        self.plot_name = plot_name
-        self.resolution = resolution
-        if rc_params is None:
-            rc_params = {'axes.labelsize': 'large',
-                         'xtick.labelsize': 'large',
-                         'ytick.labelsize': 'large',
-                         'legend.fontsize': 'large',
-                         'axes.titlesize': 'large',
-                         }
-        self.rc_params = rc_params
-        self._update_rc_params()
-
-    def _plot(self, *args):
-        """Abstract plot class needs to be implemented in inheritance."""
-        raise NotImplementedError
-
-    def _save(self, **kwargs):
-        """Store plot locally. Name of and path to plot need to be set on initialisation."""
-        plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
-        logging.debug(f"... save plot to {plot_name}")
-        plt.savefig(plot_name, dpi=self.resolution, **kwargs)
-        plt.close('all')
-
-    def _update_rc_params(self):
-        plt.rcParams.update(self.rc_params)
-
-    @staticmethod
-    def _get_sampling(sampling):
-        if sampling == "daily":
-            return "D"
-        elif sampling == "hourly":
-            return "h"
-
-    @staticmethod
-    def get_dataset_colors():
-        """
-        Standard colors used for train-, val-, and test-sets during postprocessing
-        """
-        colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"}  # hex code
-        return colors
-
-
-
 @TimeTrackingWrapper
 class PlotMonthlySummary(AbstractPlotClass):
     """
@@ -230,132 +134,6 @@ class PlotMonthlySummary(AbstractPlotClass):
         plt.tight_layout()
 
 
-@TimeTrackingWrapper
-class PlotStationMap(AbstractPlotClass):
-    """
-    Plot geographical overview of all used stations as squares.
-
-    Different data sets can be colorised by its key in the input dictionary generators. The key represents the color to
-    plot on the map. Currently, there is only a white background, but this can be adjusted by loading locally stored
-    topography data (not implemented yet). The plot is saved under plot_path with the name station_map.pdf
-
-    .. image:: ../../../../../_source/_plots/station_map.png
-        :width: 400
-    """
-
-    def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
-        """
-        Set attributes and create plot.
-
-        :param generators: dictionary with the plot color of each data set as key and the generator containing all stations
-        as value.
-        :param plot_folder: path to save the plot (default: current directory)
-        """
-        super().__init__(plot_folder, plot_name)
-        self._ax = None
-        self._gl = None
-        self._plot(generators)
-        self._save(bbox_inches="tight")
-
-    def _draw_background(self):
-        """Draw coastline, lakes, ocean, rivers and country borders as background on the map."""
-
-        import cartopy.feature as cfeature
-
-        self._ax.add_feature(cfeature.LAND.with_scale("50m"))
-        self._ax.natural_earth_shp(resolution='50m')
-        self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
-        self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
-        self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
-        self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
-        self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')
-
-    def _plot_stations(self, generators):
-        """
-        Loop over all keys in generators dict and its containing stations and plot the stations's position.
-
-        Position is highlighted by a square on the map regarding the given color.
-
-        :param generators: dictionary with the plot color of each data set as key and the generator containing all
-            stations as value.
-        """
-
-        import cartopy.crs as ccrs
-        if generators is not None:
-            legend_elements = []
-            default_colors = self.get_dataset_colors()
-            for element in generators:
-                data_collection, plot_opts = self._get_collection_and_opts(element)
-                name = data_collection.name or "unknown"
-                marker = plot_opts.get("marker", "s")
-                ms = plot_opts.get("ms", 6)
-                mec = plot_opts.get("mec", "k")
-                mfc = plot_opts.get("mfc", default_colors.get(name, "b"))
-                legend_elements.append(
-                    mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None',
-                                  label=f"{name} ({len(data_collection)})"))
-                for station in data_collection:
-                    coords = station.get_coordinates()
-                    IDx, IDy = coords["lon"], coords["lat"]
-                    self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree())
-            if len(legend_elements) > 0:
-                self._ax.legend(handles=legend_elements, loc='best')
-
-    @staticmethod
-    def _adjust_marker(marker):
-        _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
-        if isinstance(marker, int) and marker in _adjust.keys():
-            return _adjust[marker]
-        else:
-            return marker
-
-    @staticmethod
-    def _get_collection_and_opts(element):
-        if isinstance(element, tuple):
-            if len(element) == 1:
-                return element[0], {}
-            else:
-                return element
-        else:
-            return element, {}
-
-    def _plot(self, generators: List):
-        """
-        Create the station map plot.
-
-        Set figure and call all required sub-methods.
-
-        :param generators: dictionary with the plot color of each data set as key and the generator containing all
-            stations as value.
-        """
-
-        import cartopy.crs as ccrs
-        from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
-        fig = plt.figure(figsize=(10, 5))
-        self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
-        self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
-        self._gl.xformatter = LONGITUDE_FORMATTER
-        self._gl.yformatter = LATITUDE_FORMATTER
-        self._draw_background()
-        self._plot_stations(generators)
-        self._adjust_extent()
-        plt.tight_layout()
-
-    def _adjust_extent(self):
-        import cartopy.crs as ccrs
-
-        def diff(arr):
-            return arr[1] - arr[0], arr[3] - arr[2]
-
-        def find_ratio(delta, reference=5):
-            return min(max(abs(reference / delta[0]), abs(reference / delta[1])), 5)
-
-        extent = self._ax.get_extent(crs=ccrs.PlateCarree())
-        ratio = find_ratio(diff(extent))
-        new_extent = extent + np.array([-1, 1, -1, 1]) * ratio
-        self._ax.set_extent(new_extent, crs=ccrs.PlateCarree())
-
-
 @TimeTrackingWrapper
 class PlotConditionalQuantiles(AbstractPlotClass):
     """
@@ -1138,133 +916,6 @@ class PlotTimeSeries:
         return matplotlib.backends.backend_pdf.PdfPages(plot_name)
 
 
-@TimeTrackingWrapper
-class PlotAvailability(AbstractPlotClass):
-    """
-    Create data availablility plot similar to Gantt plot.
-
-    Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal
-    resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a
-    colored bar or a blank space.
-
-    You can set different colors to highlight subsets for example by providing different generators for the same index
-    using different keys in the input dictionary.
-
-    Note: each bar is surrounded by a small white box to highlight gabs in between. This can result in too long gabs
-    in display, if a gab is only very short. Also this appears on a (fluent) transition from one to another subset.
-
-    Calling this class will create three versions fo the availability plot.
-
-    1) Data availability for each element
-    1) Data availability as summary over all elements (is there at least a single elemnt for each time step)
-    1) Combination of single and overall availability
-
-    .. image:: ../../../../../_source/_plots/data_availability.png
-        :width: 400
-
-    .. image:: ../../../../../_source/_plots/data_availability_summary.png
-        :width: 400
-
-    .. image:: ../../../../../_source/_plots/data_availability_combined.png
-        :width: 400
-
-    """
-
-    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
-                 summary_name="data availability", time_dimension="datetime", window_dimension="window"):
-        """Initialise."""
-        # create standard Gantt plot for all stations (currently in single pdf file with single page)
-        super().__init__(plot_folder, "data_availability")
-        self.time_dim = time_dimension
-        self.window_dim = window_dimension
-        self.sampling = self._get_sampling(sampling)
-        self.linewidth = None
-        if self.sampling == 'h':
-            self.linewidth = 0.001
-        plot_dict = self._prepare_data(generators)
-        lgd = self._plot(plot_dict)
-        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
-        # create summary Gantt plot (is data in at least one station available)
-        self.plot_name += "_summary"
-        plot_dict_summary = self._summarise_data(generators, summary_name)
-        lgd = self._plot(plot_dict_summary)
-        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
-        # combination of station and summary plot, last element is summary broken bar
-        self.plot_name = "data_availability_combined"
-        plot_dict_summary.update(plot_dict)
-        lgd = self._plot(plot_dict_summary)
-        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
-
-    def _prepare_data(self, generators: Dict[str, DataCollection]):
-        plt_dict = {}
-        for subset, data_collection in generators.items():
-            for station in data_collection:
-                labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
-                labels_bool = labels.sel(**{self.window_dim: 1}).notnull()
-                group = (labels_bool != labels_bool.shift({self.time_dim: 1})).cumsum()
-                plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values},
-                                         index=labels.coords[self.time_dim].values)
-                t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
-                t2 = [i[1:] for i in t if i[0]]
-
-                if plt_dict.get(str(station)) is None:
-                    plt_dict[str(station)] = {subset: t2}
-                else:
-                    plt_dict[str(station)].update({subset: t2})
-        return plt_dict
-
-    def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
-        plt_dict = {}
-        for subset, data_collection in generators.items():
-            all_data = None
-            for station in data_collection:
-                labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
-                labels_bool = labels.sel(**{self.window_dim: 1}).notnull()
-                if all_data is None:
-                    all_data = labels_bool
-                else:
-                    tmp = all_data.combine_first(labels_bool)  # expand dims to merged datetime coords
-                    all_data = np.logical_or(tmp, labels_bool).combine_first(
-                        all_data)  # apply logical on merge and fill missing with all_data
-
-            group = (all_data != all_data.shift({self.time_dim: 1})).cumsum()
-            plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values},
-                                     index=all_data.coords[self.time_dim].values)
-            t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
-            t2 = [i[1:] for i in t if i[0]]
-            if plt_dict.get(summary_name) is None:
-                plt_dict[summary_name] = {subset: t2}
-            else:
-                plt_dict[summary_name].update({subset: t2})
-        return plt_dict
-
-    def _plot(self, plt_dict):
-        colors = self.get_dataset_colors()
-        _used_colors = []
-        pos = 0
-        height = 0.8  # should be <= 1
-        yticklabels = []
-        number_of_stations = len(plt_dict.keys())
-        fig, ax = plt.subplots(figsize=(10, number_of_stations / 3))
-        for station, d in sorted(plt_dict.items(), reverse=True):
-            pos += 1
-            for subset, color in colors.items():
-                plt_data = d.get(subset)
-                if plt_data is None:
-                    continue
-                elif color not in _used_colors:  # this is required for a proper legend creation
-                    _used_colors.append(color)
-                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth)
-            yticklabels.append(station)
-
-        ax.set_ylim([height, number_of_stations + 1])
-        ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2)
-        ax.set_yticklabels(yticklabels)
-        handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors]
-        lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
-        return lgd
-
-
 @TimeTrackingWrapper
 class PlotSeparationOfScales(AbstractPlotClass):
 
@@ -1292,178 +943,6 @@ class PlotSeparationOfScales(AbstractPlotClass):
             self._save()
 
 
-@TimeTrackingWrapper
-class PlotAvailabilityHistogram(AbstractPlotClass):
-    """
-    Create data availability plots as histogram.
-
-    Each entry of each generator is checked for `notnull()` values along all the datetime axis (boolean).
-    Calling this class creates two different types of histograms where each generator
-
-    1) data_availability_histogram: datetime (xaxis) vs. number of stations with availabile data (yaxis)
-    2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number
-       of samples (yaxis)
-
-    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png
-        :width: 400
-
-    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png
-        :width: 400
-
-    """
-
-    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".",
-                 subset_dim: str = 'DataSet', history_dim: str = 'window',
-                 station_dim: str = 'Stations',):
-
-        super().__init__(plot_folder, "data_availability_histogram")
-
-        self.subset_dim = subset_dim
-        self.history_dim = history_dim
-        self.station_dim = station_dim
-
-        self.freq = None
-        self.temporal_dim = None
-        self.target_dim = None
-        self._prepare_data(generators)
-
-        for plt_type in self.allowed_plot_types:
-            plot_name_tmp = self.plot_name
-            self.plot_name += '_' + plt_type
-            self._plot(plt_type=plt_type)
-            self._save()
-            self.plot_name = plot_name_tmp
-
-    def _set_dims_from_datahandler(self, data_handler):
-        self.temporal_dim = data_handler.id_class.time_dim
-        self.target_dim = data_handler.id_class.target_dim
-        self.freq = self._get_sampling(data_handler.id_class.sampling)
-
-    @property
-    def allowed_plot_types(self):
-        plot_types = ['hist', 'hist_cum']
-        return plot_types
-
-    def _prepare_data(self, generators: Dict[str, DataCollection]):
-        """
-        Prepares data to be used by plot methods.
-
-        Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim
-        """
-        avail_data_time_sum = {}
-        avail_data_station_sum = {}
-        dataset_time_interval = {}
-        for subset, generator in generators.items():
-            avail_list = []
-            for station in generator:
-                self._set_dims_from_datahandler(data_handler=station)
-                station_data_x = station.get_X(as_numpy=False)[0]
-                station_data_x = station_data_x.loc[{self.history_dim: 0,  # select recent window frame
-                                                     self.target_dim: station_data_x[self.target_dim].values[0]}]
-                station_data_x = self._reduce_dims(station_data_x)
-                avail_list.append(station_data_x.notnull())
-            avail_data = xr.concat(avail_list, dim=self.station_dim).notnull()
-            avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim)
-            avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim)
-            dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray(
-                avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict'
-            )
-        avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(),
-                                                                             name=self.subset_dim)
-                                      )
-        full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq)
-        self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(),
-                                                                                      name=self.subset_dim))
-        self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index})
-        self.dataset_time_interval = dataset_time_interval
-
-    def _reduce_dims(self, dataset):
-        if len(dataset.dims) > 2:
-            required = {self.temporal_dim, self.station_dim}
-            unimportant = set(dataset.dims).difference(required)
-            sel_dict = {un: dataset[un].values[0] for un in unimportant}
-            dataset = dataset.loc[sel_dict]
-        return dataset
-
-    @staticmethod
-    def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'):
-        if isinstance(xarray, xr.DataArray):
-            first = xarray.coords[dim_name].values[0]
-            last = xarray.coords[dim_name].values[-1]
-            if return_type == 'as_tuple':
-                return first, last
-            elif return_type == 'as_dict':
-                return {'first': first, 'last': last}
-            else:
-                raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'")
-        else:
-            raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}")
-
-    @staticmethod
-    def _make_full_time_index(irregular_time_index, freq):
-        full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq)
-        return full_time_index
-
-    def _plot(self, plt_type='hist', *args):
-        if plt_type == 'hist':
-            self._plot_hist()
-        elif plt_type == 'hist_cum':
-            self._plot_hist_cum()
-        else:
-            raise ValueError(f"plt_type mus be 'hist' or 'hist_cum', but is {type}")
-
-    def _plot_hist(self, *args):
-        colors = self.get_dataset_colors()
-        fig, axes = plt.subplots(figsize=(10, 3))
-        for i, subset in enumerate(self.dataset_time_interval.keys()):
-            plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset,
-                                                       self.temporal_dim: slice(
-                                                           self.dataset_time_interval[subset]['first'],
-                                                           self.dataset_time_interval[subset]['last']
-                                                       )
-                                                       }
-                                                      )
-
-            plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset)
-            plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset])
-
-        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
-                         facecolor='white', framealpha=1, edgecolor='black')
-        for lgd_line in lgd.get_lines():
-            lgd_line.set_linewidth(4.0)
-        plt.gca().xaxis.set_major_locator(mdates.YearLocator())
-        plt.title('')
-        plt.ylabel('Number of samples')
-        plt.tight_layout()
-
-    def _plot_hist_cum(self, *args):
-        colors = self.get_dataset_colors()
-        fig, axes = plt.subplots(figsize=(10, 3))
-        n_bins = int(self.avail_data_cum_sum.max().values)
-        bins = np.arange(0, n_bins+1)
-        descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby(
-            self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False
-        ).coords[self.subset_dim].values
-
-        for subset in descending_subsets:
-            self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes,
-                                                                             bins=bins,
-                                                                             label=subset,
-                                                                             cumulative=-1,
-                                                                             color=colors[subset],
-                                                                             # alpha=.5
-                                                                             )
-
-        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
-                         facecolor='white', framealpha=1, edgecolor='black')
-        plt.title('')
-        plt.ylabel('Number of stations')
-        plt.xlabel('Number of samples')
-        plt.xlim((bins[0], bins[-1]))
-        plt.tight_layout()
-
-
-
 if __name__ == "__main__":
     stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
     path = "../../testrun_network/forecasts"
diff --git a/mlair/plotting/preprocessing_plotting.py b/mlair/plotting/preprocessing_plotting.py
new file mode 100644
index 00000000..aa61b1f3
--- /dev/null
+++ b/mlair/plotting/preprocessing_plotting.py
@@ -0,0 +1,438 @@
+"""Collection of plots to get more insight into data."""
+__author__ = "Lukas Leufen, Felix Kleinert"
+__date__ = '2021-04-13'
+
+from typing import List, Dict
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates
+
+from mlair.data_handler import DataCollection
+from mlair.helpers import TimeTrackingWrapper
+from mlair.plotting.abstract_plot_class import AbstractPlotClass
+
+
+@TimeTrackingWrapper
+class PlotStationMap(AbstractPlotClass):
+    """
+    Plot geographical overview of all used stations as squares.
+
+    Different data sets can be colorised by plot options (e.g. mfc) given together with each collection in the input
+    list generators. Currently, there is only a white background, but this can be adjusted by loading locally stored
+    topography data (not implemented yet). The plot is saved under plot_path with the name station_map.pdf
+
+    .. image:: ../../../../../_source/_plots/station_map.png
+        :width: 400
+    """
+
+    def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
+        """
+        Set attributes and create plot.
+
+        :param generators: list of data collections to plot; each entry is either a collection or a tuple of
+            (collection, plot options dict with optional keys marker, ms, mec, mfc).
+        :param plot_folder: path to save the plot (default: current directory)
+        """
+        super().__init__(plot_folder, plot_name)
+        self._ax = None
+        self._gl = None
+        self._plot(generators)
+        self._save(bbox_inches="tight")
+
+    def _draw_background(self):
+        """Draw coastline, lakes, ocean, rivers and country borders as background on the map."""
+
+        import cartopy.feature as cfeature
+
+        self._ax.add_feature(cfeature.LAND.with_scale("50m"))
+        self._ax.natural_earth_shp(resolution='50m')
+        self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
+        self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
+        self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
+        self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
+        self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')
+
+    def _plot_stations(self, generators):
+        """
+        Loop over all entries in generators and their containing stations and plot each station's position.
+
+        Position is highlighted by a square on the map regarding the given color.
+
+        :param generators: list of data collections to plot; each entry is either a collection or a tuple of
+            (collection, plot options dict with optional keys marker, ms, mec, mfc).
+        """
+
+        import cartopy.crs as ccrs
+        if generators is not None:
+            legend_elements = []
+            default_colors = self.get_dataset_colors()
+            for element in generators:
+                data_collection, plot_opts = self._get_collection_and_opts(element)
+                name = data_collection.name or "unknown"
+                marker = plot_opts.get("marker", "s")
+                ms = plot_opts.get("ms", 6)
+                mec = plot_opts.get("mec", "k")
+                mfc = plot_opts.get("mfc", default_colors.get(name, "b"))
+                legend_elements.append(
+                    mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None',
+                                  label=f"{name} ({len(data_collection)})"))
+                for station in data_collection:
+                    coords = station.get_coordinates()
+                    IDx, IDy = coords["lon"], coords["lat"]
+                    self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree())
+            if len(legend_elements) > 0:
+                self._ax.legend(handles=legend_elements, loc='best')
+
+    @staticmethod
+    def _adjust_marker(marker):
+        _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
+        if isinstance(marker, int) and marker in _adjust.keys():
+            return _adjust[marker]
+        else:
+            return marker
+
+    @staticmethod
+    def _get_collection_and_opts(element):
+        if isinstance(element, tuple):
+            if len(element) == 1:
+                return element[0], {}
+            else:
+                return element
+        else:
+            return element, {}
+
+    def _plot(self, generators: List):
+        """
+        Create the station map plot.
+
+        Set figure and call all required sub-methods.
+
+        :param generators: list of data collections to plot; each entry is either a collection or a tuple of
+            (collection, plot options dict with optional keys marker, ms, mec, mfc).
+        """
+
+        import cartopy.crs as ccrs
+        from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
+        fig = plt.figure(figsize=(10, 5))
+        self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
+        self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
+        self._gl.xformatter = LONGITUDE_FORMATTER
+        self._gl.yformatter = LATITUDE_FORMATTER
+        self._draw_background()
+        self._plot_stations(generators)
+        self._adjust_extent()
+        plt.tight_layout()
+
+    def _adjust_extent(self):
+        import cartopy.crs as ccrs
+
+        def diff(arr):
+            return arr[1] - arr[0], arr[3] - arr[2]
+
+        def find_ratio(delta, reference=5):
+            return min(max(abs(reference / delta[0]), abs(reference / delta[1])), 5)
+
+        extent = self._ax.get_extent(crs=ccrs.PlateCarree())
+        ratio = find_ratio(diff(extent))
+        new_extent = extent + np.array([-1, 1, -1, 1]) * ratio
+        self._ax.set_extent(new_extent, crs=ccrs.PlateCarree())
+
+
+@TimeTrackingWrapper
+class PlotAvailability(AbstractPlotClass):
+    """
+    Create data availability plot similar to Gantt plot.
+
+    Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal
+    resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a
+    colored bar or a blank space.
+
+    You can set different colors to highlight subsets for example by providing different generators for the same index
+    using different keys in the input dictionary.
+
+    Note: each bar is surrounded by a small white box to highlight gaps in between. This can result in too long gaps
+    in display, if a gap is only very short. Also this appears on a (fluent) transition from one to another subset.
+
+    Calling this class will create three versions of the availability plot.
+
+    1) Data availability for each element
+    2) Data availability as summary over all elements (is there at least a single element for each time step)
+    3) Combination of single and overall availability
+
+    .. image:: ../../../../../_source/_plots/data_availability.png
+        :width: 400
+
+    .. image:: ../../../../../_source/_plots/data_availability_summary.png
+        :width: 400
+
+    .. image:: ../../../../../_source/_plots/data_availability_combined.png
+        :width: 400
+
+    """
+
+    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
+                 summary_name="data availability", time_dimension="datetime", window_dimension="window"):
+        """Initialise."""
+        # create standard Gantt plot for all stations (currently in single pdf file with single page)
+        super().__init__(plot_folder, "data_availability")
+        self.time_dim = time_dimension
+        self.window_dim = window_dimension
+        self.sampling = self._get_sampling(sampling)
+        self.linewidth = None
+        if self.sampling == 'h':
+            self.linewidth = 0.001
+        plot_dict = self._prepare_data(generators)
+        lgd = self._plot(plot_dict)
+        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
+        # create summary Gantt plot (is data in at least one station available)
+        self.plot_name += "_summary"
+        plot_dict_summary = self._summarise_data(generators, summary_name)
+        lgd = self._plot(plot_dict_summary)
+        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
+        # combination of station and summary plot, last element is summary broken bar
+        self.plot_name = "data_availability_combined"
+        plot_dict_summary.update(plot_dict)
+        lgd = self._plot(plot_dict_summary)
+        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
+
+    def _prepare_data(self, generators: Dict[str, DataCollection]):
+        plt_dict = {}  # maps str(station) -> {subset: [(first index value, number of rows), ...]}
+        for subset, data_collection in generators.items():
+            for station in data_collection:
+                labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
+                labels_bool = labels.sel(**{self.window_dim: 1}).notnull()
+                group = (labels_bool != labels_bool.shift({self.time_dim: 1})).cumsum()  # id of each constant run
+                plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values},
+                                         index=labels.coords[self.time_dim].values)
+                t = plot_data.groupby("group").apply(lambda x: (x["avail"].iloc[0], x.index[0], x.shape[0]))
+                t2 = [i[1:] for i in t if i[0]]  # keep only runs whose first "avail" value is True
+                # .iloc[0] reads the first row by position; an integer key on a non-integer index is deprecated
+                if plt_dict.get(str(station)) is None:
+                    plt_dict[str(station)] = {subset: t2}
+                else:
+                    plt_dict[str(station)].update({subset: t2})
+        return plt_dict
+
+    def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
+        plt_dict = {}  # maps summary_name -> {subset: [(first index value, number of rows), ...]}
+        for subset, data_collection in generators.items():
+            all_data = None
+            for station in data_collection:
+                labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
+                labels_bool = labels.sel(**{self.window_dim: 1}).notnull()
+                if all_data is None:
+                    all_data = labels_bool
+                else:
+                    tmp = all_data.combine_first(labels_bool)  # expand dims to merged datetime coords
+                    all_data = np.logical_or(tmp, labels_bool).combine_first(
+                        all_data)  # apply logical on merge and fill missing with all_data
+            # .iloc[0] below reads the first row by position; an integer key on a non-integer index is deprecated
+            group = (all_data != all_data.shift({self.time_dim: 1})).cumsum()  # id of each constant run
+            plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values},
+                                     index=all_data.coords[self.time_dim].values)
+            t = plot_data.groupby("group").apply(lambda x: (x["avail"].iloc[0], x.index[0], x.shape[0]))
+            t2 = [i[1:] for i in t if i[0]]  # keep only runs whose first "avail" value is True
+            if plt_dict.get(summary_name) is None:
+                plt_dict[summary_name] = {subset: t2}
+            else:
+                plt_dict[summary_name].update({subset: t2})
+        return plt_dict
+
+    def _plot(self, plt_dict):
+        colors = self.get_dataset_colors()
+        _used_colors = []
+        pos = 0
+        height = 0.8  # should be <= 1
+        yticklabels = []
+        number_of_stations = len(plt_dict.keys())
+        fig, ax = plt.subplots(figsize=(10, number_of_stations / 3))
+        for station, d in sorted(plt_dict.items(), reverse=True):
+            pos += 1
+            for subset, color in colors.items():
+                plt_data = d.get(subset)
+                if plt_data is None:
+                    continue
+                elif color not in _used_colors:  # this is required for a proper legend creation
+                    _used_colors.append(color)
+                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth)
+            yticklabels.append(station)
+
+        ax.set_ylim([height, number_of_stations + 1])
+        ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2)
+        ax.set_yticklabels(yticklabels)
+        handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors]
+        lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
+        return lgd
+
+
+@TimeTrackingWrapper
+class PlotAvailabilityHistogram(AbstractPlotClass):
+    """
+    Create data availability plots as histogram.
+
+    Each entry of each generator is checked for `notnull()` values along the entire datetime axis (boolean).
+    Calling this class creates two different types of histograms, each showing all generators:
+
+    1) data_availability_histogram: datetime (xaxis) vs. number of stations with available data (yaxis)
+    2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number
+       of samples (yaxis)
+
+    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png
+        :width: 400
+
+    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png
+        :width: 400
+
+    """
+
+    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".",
+                 subset_dim: str = 'DataSet', history_dim: str = 'window',
+                 station_dim: str = 'Stations', ):
+
+        super().__init__(plot_folder, "data_availability_histogram")
+
+        self.subset_dim = subset_dim
+        self.history_dim = history_dim
+        self.station_dim = station_dim
+
+        self.freq = None
+        self.temporal_dim = None
+        self.target_dim = None
+        self._prepare_data(generators)
+
+        for plt_type in self.allowed_plot_types:
+            plot_name_tmp = self.plot_name
+            self.plot_name += '_' + plt_type
+            self._plot(plt_type=plt_type)
+            self._save()
+            self.plot_name = plot_name_tmp
+
+    def _set_dims_from_datahandler(self, data_handler):
+        self.temporal_dim = data_handler.id_class.time_dim
+        self.target_dim = data_handler.id_class.target_dim
+        self.freq = self._get_sampling(data_handler.id_class.sampling)
+
+    @property
+    def allowed_plot_types(self):
+        plot_types = ['hist', 'hist_cum']
+        return plot_types
+
+    def _prepare_data(self, generators: Dict[str, DataCollection]):
+        """
+        Prepares data to be used by plot methods.
+
+        Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim
+        """
+        avail_data_time_sum = {}
+        avail_data_station_sum = {}
+        dataset_time_interval = {}
+        for subset, generator in generators.items():
+            avail_list = []
+            for station in generator:
+                self._set_dims_from_datahandler(data_handler=station)
+                station_data_x = station.get_X(as_numpy=False)[0]
+                station_data_x = station_data_x.loc[{self.history_dim: 0,  # select recent window frame
+                                                     self.target_dim: station_data_x[self.target_dim].values[0]}]
+                station_data_x = self._reduce_dims(station_data_x)
+                avail_list.append(station_data_x.notnull())
+            avail_data = xr.concat(avail_list, dim=self.station_dim).notnull()
+            avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim)
+            avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim)
+            dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray(
+                avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict'
+            )
+        avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(),
+                                                                             name=self.subset_dim)
+                                      )
+        full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq)
+        self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(),
+                                                                                      name=self.subset_dim))
+        self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index})
+        self.dataset_time_interval = dataset_time_interval
+
+    def _reduce_dims(self, dataset):
+        if len(dataset.dims) > 2:
+            required = {self.temporal_dim, self.station_dim}
+            unimportant = set(dataset.dims).difference(required)
+            sel_dict = {un: dataset[un].values[0] for un in unimportant}
+            dataset = dataset.loc[sel_dict]
+        return dataset
+
+    @staticmethod
+    def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'):
+        if isinstance(xarray, xr.DataArray):
+            first = xarray.coords[dim_name].values[0]
+            last = xarray.coords[dim_name].values[-1]
+            if return_type == 'as_tuple':
+                return first, last
+            elif return_type == 'as_dict':
+                return {'first': first, 'last': last}
+            else:
+                raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'")
+        else:
+            raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}")
+
+    @staticmethod
+    def _make_full_time_index(irregular_time_index, freq):
+        full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq)
+        return full_time_index
+
+    def _plot(self, plt_type='hist', *args):  # dispatch to the plot routine selected by plt_type
+        if plt_type == 'hist':
+            self._plot_hist()
+        elif plt_type == 'hist_cum':
+            self._plot_hist_cum()
+        else:  # report the offending value itself, not the builtin `type`
+            raise ValueError(f"plt_type must be 'hist' or 'hist_cum', but is {plt_type}")
+
+    def _plot_hist(self, *args):
+        colors = self.get_dataset_colors()
+        fig, axes = plt.subplots(figsize=(10, 3))
+        for i, subset in enumerate(self.dataset_time_interval.keys()):
+            plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset,
+                                                       self.temporal_dim: slice(
+                                                           self.dataset_time_interval[subset]['first'],
+                                                           self.dataset_time_interval[subset]['last']
+                                                       )
+                                                       }
+                                                      )
+
+            plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset)
+            plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset])
+
+        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
+                         facecolor='white', framealpha=1, edgecolor='black')
+        for lgd_line in lgd.get_lines():
+            lgd_line.set_linewidth(4.0)
+        plt.gca().xaxis.set_major_locator(mdates.YearLocator())
+        plt.title('')
+        plt.ylabel('Number of samples')
+        plt.tight_layout()
+
+    def _plot_hist_cum(self, *args):
+        colors = self.get_dataset_colors()
+        fig, axes = plt.subplots(figsize=(10, 3))
+        n_bins = int(self.avail_data_cum_sum.max().values)
+        bins = np.arange(0, n_bins + 1)
+        descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby(
+            self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False
+        ).coords[self.subset_dim].values
+
+        for subset in descending_subsets:
+            self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes,
+                                                                             bins=bins,
+                                                                             label=subset,
+                                                                             cumulative=-1,
+                                                                             color=colors[subset],
+                                                                             # alpha=.5
+                                                                             )
+
+        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
+                         facecolor='white', framealpha=1, edgecolor='black')
+        plt.title('')
+        plt.ylabel('Number of stations')
+        plt.xlabel('Number of samples')
+        plt.xlim((bins[0], bins[-1]))
+        plt.tight_layout()
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 73aebb00..ff74da37 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -19,9 +19,9 @@ from mlair.helpers.datastore import NameNotFoundInDataStore
 from mlair.helpers import TimeTracking, statistics, extract_value, remove_items, to_list, tables
 from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules import AbstractModelClass
-from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
-    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotAvailabilityHistogram, \
-    PlotConditionalQuantiles, PlotSeparationOfScales
+from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \
+    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, PlotSeparationOfScales
+from mlair.plotting.preprocessing_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram
 from mlair.run_modules.run_environment import RunEnvironment
 
 
-- 
GitLab