From 317f24468d546fc91a9801bc4f07496d6d31d0aa Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Tue, 13 Apr 2021 11:49:37 +0200 Subject: [PATCH] split plots into abstract, pre, and postprocessing plots --- mlair/plotting/abstract_plot_class.py | 101 +++++ mlair/plotting/postprocessing_plotting.py | 523 +--------------------- mlair/plotting/preprocessing_plotting.py | 438 ++++++++++++++++++ mlair/run_modules/post_processing.py | 6 +- 4 files changed, 543 insertions(+), 525 deletions(-) create mode 100644 mlair/plotting/abstract_plot_class.py create mode 100644 mlair/plotting/preprocessing_plotting.py diff --git a/mlair/plotting/abstract_plot_class.py b/mlair/plotting/abstract_plot_class.py new file mode 100644 index 00000000..dab45156 --- /dev/null +++ b/mlair/plotting/abstract_plot_class.py @@ -0,0 +1,101 @@ +"""Abstract plot class that should be used for preprocessing and postprocessing plots.""" +__author__ = "Lukas Leufen" +__date__ = '2021-04-13' + +import logging +import os + +from matplotlib import pyplot as plt + + +class AbstractPlotClass: + """ + Abstract class for all plotting routines to unify plot workflow. + + Each inheritance requires a _plot method. Create a plot class like: + + .. code-block:: python + + class MyCustomPlot(AbstractPlotClass): + + def __init__(self, plot_folder, *args, **kwargs): + super().__init__(plot_folder, "custom_plot_name") + self._data = self._prepare_data(*args, **kwargs) + self._plot(*args, **kwargs) + self._save() + + def _prepare_data(*args, **kwargs): + <your custom data preparation> + return data + + def _plot(*args, **kwargs): + <your custom plotting without saving> + + The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are + using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per default (can be + set in super class initialisation). + + Methods like the shown _prepare_data() are optional. The only method required to implement is _plot. + + If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot + class. It will log the spent time if you call your plotting without saving the returned object. + + .. code-block:: python + + @TimeTrackingWrapper + class MyCustomPlot(AbstractPlotClass): + pass + + Let's assume it takes a while to create this very special plot. + + >>> MyCustomPlot() + INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss) + + """ + + def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None): + """Set up plot folder and name, and plot resolution (default 500dpi).""" + plot_folder = os.path.abspath(plot_folder) + if not os.path.exists(plot_folder): + os.makedirs(plot_folder) + self.plot_folder = plot_folder + self.plot_name = plot_name + self.resolution = resolution + if rc_params is None: + rc_params = {'axes.labelsize': 'large', + 'xtick.labelsize': 'large', + 'ytick.labelsize': 'large', + 'legend.fontsize': 'large', + 'axes.titlesize': 'large', + } + self.rc_params = rc_params + self._update_rc_params() + + def _plot(self, *args): + """Abstract plot class needs to be implemented in inheritance.""" + raise NotImplementedError + + def _save(self, **kwargs): + """Store plot locally. Name of and path to plot need to be set on initialisation.""" + plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf") + logging.debug(f"... save plot to {plot_name}") + plt.savefig(plot_name, dpi=self.resolution, **kwargs) + plt.close('all') + + def _update_rc_params(self): + plt.rcParams.update(self.rc_params) + + @staticmethod + def _get_sampling(sampling): + if sampling == "daily": + return "D" + elif sampling == "hourly": + return "h" + + @staticmethod + def get_dataset_colors(): + """ + Standard colors used for train-, val-, and test-sets during postprocessing + """ + colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"} # hex code + return colors diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index d769fabc..491aa52e 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -9,10 +9,7 @@ import warnings from typing import Dict, List, Tuple import matplotlib -import matplotlib.patches as mpatches -import matplotlib.lines as mlines import matplotlib.pyplot as plt -import matplotlib.dates as mdates import numpy as np import pandas as pd import seaborn as sns @@ -22,6 +19,7 @@ from matplotlib.backends.backend_pdf import PdfPages from mlair import helpers from mlair.data_handler.iterator import DataCollection from mlair.helpers import TimeTrackingWrapper +from mlair.plotting.abstract_plot_class import AbstractPlotClass logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -31,100 +29,6 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING) # import matplotlib.pyplot as plt -class AbstractPlotClass: - """ - Abstract class for all plotting routines to unify plot workflow. - - Each inheritance requires a _plot method. Create a plot class like: - - .. code-block:: python - - class MyCustomPlot(AbstractPlotClass): - - def __init__(self, plot_folder, *args, **kwargs): - super().__init__(plot_folder, "custom_plot_name") - self._data = self._prepare_data(*args, **kwargs) - self._plot(*args, **kwargs) - self._save() - - def _prepare_data(*args, **kwargs): - <your custom data preparation> - return data - - def _plot(*args, **kwargs): - <your custom plotting without saving> - - The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are - using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per default (can be - set in super class initialisation). - - Methods like the shown _prepare_data() are optional. The only method required to implement is _plot. - - If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot - class. It will log the spent time if you call your plotting without saving the returned object. - - .. code-block:: python - - @TimeTrackingWrapper - class MyCustomPlot(AbstractPlotClass): - pass - - Let's assume it takes a while to create this very special plot. - - >>> MyCustomPlot() - INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss) - - """ - - def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None): - """Set up plot folder and name, and plot resolution (default 500dpi).""" - plot_folder = os.path.abspath(plot_folder) - if not os.path.exists(plot_folder): - os.makedirs(plot_folder) - self.plot_folder = plot_folder - self.plot_name = plot_name - self.resolution = resolution - if rc_params is None: - rc_params = {'axes.labelsize': 'large', - 'xtick.labelsize': 'large', - 'ytick.labelsize': 'large', - 'legend.fontsize': 'large', - 'axes.titlesize': 'large', - } - self.rc_params = rc_params - self._update_rc_params() - - def _plot(self, *args): - """Abstract plot class needs to be implemented in inheritance.""" - raise NotImplementedError - - def _save(self, **kwargs): - """Store plot locally. Name of and path to plot need to be set on initialisation.""" - plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf") - logging.debug(f"... save plot to {plot_name}") - plt.savefig(plot_name, dpi=self.resolution, **kwargs) - plt.close('all') - - def _update_rc_params(self): - plt.rcParams.update(self.rc_params) - - @staticmethod - def _get_sampling(sampling): - if sampling == "daily": - return "D" - elif sampling == "hourly": - return "h" - - @staticmethod - def get_dataset_colors(): - """ - Standard colors used for train-, val-, and test-sets during postprocessing - """ - colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"} # hex code - return colors - - - @TimeTrackingWrapper class PlotMonthlySummary(AbstractPlotClass): """ @@ -230,132 +134,6 @@ class PlotMonthlySummary(AbstractPlotClass): plt.tight_layout() -@TimeTrackingWrapper -class PlotStationMap(AbstractPlotClass): - """ - Plot geographical overview of all used stations as squares. - - Different data sets can be colorised by its key in the input dictionary generators. The key represents the color to - plot on the map. Currently, there is only a white background, but this can be adjusted by loading locally stored - topography data (not implemented yet). The plot is saved under plot_path with the name station_map.pdf - - .. image:: ../../../../../_source/_plots/station_map.png - :width: 400 - """ - - def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"): - """ - Set attributes and create plot. - - :param generators: dictionary with the plot color of each data set as key and the generator containing all stations - as value. - :param plot_folder: path to save the plot (default: current directory) - """ - super().__init__(plot_folder, plot_name) - self._ax = None - self._gl = None - self._plot(generators) - self._save(bbox_inches="tight") - - def _draw_background(self): - """Draw coastline, lakes, ocean, rivers and country borders as background on the map.""" - - import cartopy.feature as cfeature - - self._ax.add_feature(cfeature.LAND.with_scale("50m")) - self._ax.natural_earth_shp(resolution='50m') - self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black') - self._ax.add_feature(cfeature.LAKES.with_scale("50m")) - self._ax.add_feature(cfeature.OCEAN.with_scale("50m")) - self._ax.add_feature(cfeature.RIVERS.with_scale("50m")) - self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black') - - def _plot_stations(self, generators): - """ - Loop over all keys in generators dict and its containing stations and plot the stations's position. - - Position is highlighted by a square on the map regarding the given color. - - :param generators: dictionary with the plot color of each data set as key and the generator containing all - stations as value. - """ - - import cartopy.crs as ccrs - if generators is not None: - legend_elements = [] - default_colors = self.get_dataset_colors() - for element in generators: - data_collection, plot_opts = self._get_collection_and_opts(element) - name = data_collection.name or "unknown" - marker = plot_opts.get("marker", "s") - ms = plot_opts.get("ms", 6) - mec = plot_opts.get("mec", "k") - mfc = plot_opts.get("mfc", default_colors.get(name, "b")) - legend_elements.append( - mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None', - label=f"{name} ({len(data_collection)})")) - for station in data_collection: - coords = station.get_coordinates() - IDx, IDy = coords["lon"], coords["lat"] - self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree()) - if len(legend_elements) > 0: - self._ax.legend(handles=legend_elements, loc='best') - - @staticmethod - def _adjust_marker(marker): - _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"} - if isinstance(marker, int) and marker in _adjust.keys(): - return _adjust[marker] - else: - return marker - - @staticmethod - def _get_collection_and_opts(element): - if isinstance(element, tuple): - if len(element) == 1: - return element[0], {} - else: - return element - else: - return element, {} - - def _plot(self, generators: List): - """ - Create the station map plot. - - Set figure and call all required sub-methods. - - :param generators: dictionary with the plot color of each data set as key and the generator containing all - stations as value. - """ - - import cartopy.crs as ccrs - from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER - fig = plt.figure(figsize=(10, 5)) - self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) - self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True) - self._gl.xformatter = LONGITUDE_FORMATTER - self._gl.yformatter = LATITUDE_FORMATTER - self._draw_background() - self._plot_stations(generators) - self._adjust_extent() - plt.tight_layout() - - def _adjust_extent(self): - import cartopy.crs as ccrs - - def diff(arr): - return arr[1] - arr[0], arr[3] - arr[2] - - def find_ratio(delta, reference=5): - return min(max(abs(reference / delta[0]), abs(reference / delta[1])), 5) - - extent = self._ax.get_extent(crs=ccrs.PlateCarree()) - ratio = find_ratio(diff(extent)) - new_extent = extent + np.array([-1, 1, -1, 1]) * ratio - self._ax.set_extent(new_extent, crs=ccrs.PlateCarree()) - - @TimeTrackingWrapper class PlotConditionalQuantiles(AbstractPlotClass): """ @@ -1138,133 +916,6 @@ class PlotTimeSeries: return matplotlib.backends.backend_pdf.PdfPages(plot_name) -@TimeTrackingWrapper -class PlotAvailability(AbstractPlotClass): - """ - Create data availablility plot similar to Gantt plot. - - Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal - resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a - colored bar or a blank space. - - You can set different colors to highlight subsets for example by providing different generators for the same index - using different keys in the input dictionary. - - Note: each bar is surrounded by a small white box to highlight gabs in between. This can result in too long gabs - in display, if a gab is only very short. Also this appears on a (fluent) transition from one to another subset. - - Calling this class will create three versions fo the availability plot. - - 1) Data availability for each element - 1) Data availability as summary over all elements (is there at least a single elemnt for each time step) - 1) Combination of single and overall availability - - .. image:: ../../../../../_source/_plots/data_availability.png - :width: 400 - - .. image:: ../../../../../_source/_plots/data_availability_summary.png - :width: 400 - - .. image:: ../../../../../_source/_plots/data_availability_combined.png - :width: 400 - - """ - - def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily", - summary_name="data availability", time_dimension="datetime", window_dimension="window"): - """Initialise.""" - # create standard Gantt plot for all stations (currently in single pdf file with single page) - super().__init__(plot_folder, "data_availability") - self.time_dim = time_dimension - self.window_dim = window_dimension - self.sampling = self._get_sampling(sampling) - self.linewidth = None - if self.sampling == 'h': - self.linewidth = 0.001 - plot_dict = self._prepare_data(generators) - lgd = self._plot(plot_dict) - self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") - # create summary Gantt plot (is data in at least one station available) - self.plot_name += "_summary" - plot_dict_summary = self._summarise_data(generators, summary_name) - lgd = self._plot(plot_dict_summary) - self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") - # combination of station and summary plot, last element is summary broken bar - self.plot_name = "data_availability_combined" - plot_dict_summary.update(plot_dict) - lgd = self._plot(plot_dict_summary) - self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") - - def _prepare_data(self, generators: Dict[str, DataCollection]): - plt_dict = {} - for subset, data_collection in generators.items(): - for station in data_collection: - labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean() - labels_bool = labels.sel(**{self.window_dim: 1}).notnull() - group = (labels_bool != labels_bool.shift({self.time_dim: 1})).cumsum() - plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values}, - index=labels.coords[self.time_dim].values) - t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) - t2 = [i[1:] for i in t if i[0]] - - if plt_dict.get(str(station)) is None: - plt_dict[str(station)] = {subset: t2} - else: - plt_dict[str(station)].update({subset: t2}) - return plt_dict - - def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str): - plt_dict = {} - for subset, data_collection in generators.items(): - all_data = None - for station in data_collection: - labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean() - labels_bool = labels.sel(**{self.window_dim: 1}).notnull() - if all_data is None: - all_data = labels_bool - else: - tmp = all_data.combine_first(labels_bool) # expand dims to merged datetime coords - all_data = np.logical_or(tmp, labels_bool).combine_first( - all_data) # apply logical on merge and fill missing with all_data - - group = (all_data != all_data.shift({self.time_dim: 1})).cumsum() - plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, - index=all_data.coords[self.time_dim].values) - t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) - t2 = [i[1:] for i in t if i[0]] - if plt_dict.get(summary_name) is None: - plt_dict[summary_name] = {subset: t2} - else: - plt_dict[summary_name].update({subset: t2}) - return plt_dict - - def _plot(self, plt_dict): - colors = self.get_dataset_colors() - _used_colors = [] - pos = 0 - height = 0.8 # should be <= 1 - yticklabels = [] - number_of_stations = len(plt_dict.keys()) - fig, ax = plt.subplots(figsize=(10, number_of_stations / 3)) - for station, d in sorted(plt_dict.items(), reverse=True): - pos += 1 - for subset, color in colors.items(): - plt_data = d.get(subset) - if plt_data is None: - continue - elif color not in _used_colors: # this is required for a proper legend creation - _used_colors.append(color) - ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth) - yticklabels.append(station) - - ax.set_ylim([height, number_of_stations + 1]) - ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2) - ax.set_yticklabels(yticklabels) - handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors] - lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles)) - return lgd - - @TimeTrackingWrapper class PlotSeparationOfScales(AbstractPlotClass): @@ -1292,178 +943,6 @@ class PlotSeparationOfScales(AbstractPlotClass): self._save() -@TimeTrackingWrapper -class PlotAvailabilityHistogram(AbstractPlotClass): - """ - Create data availability plots as histogram. - - Each entry of each generator is checked for `notnull()` values along all the datetime axis (boolean). - Calling this class creates two different types of histograms where each generator - - 1) data_availability_histogram: datetime (xaxis) vs. number of stations with availabile data (yaxis) - 2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number - of samples (yaxis) - - .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png - :width: 400 - - .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png - :width: 400 - - """ - - def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", - subset_dim: str = 'DataSet', history_dim: str = 'window', - station_dim: str = 'Stations',): - - super().__init__(plot_folder, "data_availability_histogram") - - self.subset_dim = subset_dim - self.history_dim = history_dim - self.station_dim = station_dim - - self.freq = None - self.temporal_dim = None - self.target_dim = None - self._prepare_data(generators) - - for plt_type in self.allowed_plot_types: - plot_name_tmp = self.plot_name - self.plot_name += '_' + plt_type - self._plot(plt_type=plt_type) - self._save() - self.plot_name = plot_name_tmp - - def _set_dims_from_datahandler(self, data_handler): - self.temporal_dim = data_handler.id_class.time_dim - self.target_dim = data_handler.id_class.target_dim - self.freq = self._get_sampling(data_handler.id_class.sampling) - - @property - def allowed_plot_types(self): - plot_types = ['hist', 'hist_cum'] - return plot_types - - def _prepare_data(self, generators: Dict[str, DataCollection]): - """ - Prepares data to be used by plot methods. - - Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim - """ - avail_data_time_sum = {} - avail_data_station_sum = {} - dataset_time_interval = {} - for subset, generator in generators.items(): - avail_list = [] - for station in generator: - self._set_dims_from_datahandler(data_handler=station) - station_data_x = station.get_X(as_numpy=False)[0] - station_data_x = station_data_x.loc[{self.history_dim: 0, # select recent window frame - self.target_dim: station_data_x[self.target_dim].values[0]}] - station_data_x = self._reduce_dims(station_data_x) - avail_list.append(station_data_x.notnull()) - avail_data = xr.concat(avail_list, dim=self.station_dim).notnull() - avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim) - avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim) - dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray( - avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict' - ) - avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(), - name=self.subset_dim) - ) - full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq) - self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(), - name=self.subset_dim)) - self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index}) - self.dataset_time_interval = dataset_time_interval - - def _reduce_dims(self, dataset): - if len(dataset.dims) > 2: - required = {self.temporal_dim, self.station_dim} - unimportant = set(dataset.dims).difference(required) - sel_dict = {un: dataset[un].values[0] for un in unimportant} - dataset = dataset.loc[sel_dict] - return dataset - - @staticmethod - def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'): - if isinstance(xarray, xr.DataArray): - first = xarray.coords[dim_name].values[0] - last = xarray.coords[dim_name].values[-1] - if return_type == 'as_tuple': - return first, last - elif return_type == 'as_dict': - return {'first': first, 'last': last} - else: - raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'") - else: - raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}") - - @staticmethod - def _make_full_time_index(irregular_time_index, freq): - full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq) - return full_time_index - - def _plot(self, plt_type='hist', *args): - if plt_type == 'hist': - self._plot_hist() - elif plt_type == 'hist_cum': - self._plot_hist_cum() - else: - raise ValueError(f"plt_type mus be 'hist' or 'hist_cum', but is {type}") - - def _plot_hist(self, *args): - colors = self.get_dataset_colors() - fig, axes = plt.subplots(figsize=(10, 3)) - for i, subset in enumerate(self.dataset_time_interval.keys()): - plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset, - self.temporal_dim: slice( - self.dataset_time_interval[subset]['first'], - self.dataset_time_interval[subset]['last'] - ) - } - ) - - plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset) - plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset]) - - lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval), - facecolor='white', framealpha=1, edgecolor='black') - for lgd_line in lgd.get_lines(): - lgd_line.set_linewidth(4.0) - plt.gca().xaxis.set_major_locator(mdates.YearLocator()) - plt.title('') - plt.ylabel('Number of samples') - plt.tight_layout() - - def _plot_hist_cum(self, *args): - colors = self.get_dataset_colors() - fig, axes = plt.subplots(figsize=(10, 3)) - n_bins = int(self.avail_data_cum_sum.max().values) - bins = np.arange(0, n_bins+1) - descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby( - self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False - ).coords[self.subset_dim].values - - for subset in descending_subsets: - self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes, - bins=bins, - label=subset, - cumulative=-1, - color=colors[subset], - # alpha=.5 - ) - - lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval), - facecolor='white', framealpha=1, edgecolor='black') - plt.title('') - plt.ylabel('Number of stations') - plt.xlabel('Number of samples') - plt.xlim((bins[0], bins[-1])) - plt.tight_layout() - - - if __name__ == "__main__": stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] path = "../../testrun_network/forecasts" diff --git a/mlair/plotting/preprocessing_plotting.py b/mlair/plotting/preprocessing_plotting.py new file mode 100644 index 00000000..aa61b1f3 --- /dev/null +++ b/mlair/plotting/preprocessing_plotting.py @@ -0,0 +1,438 @@ +"""Collection of plots to get more insight into data.""" +__author__ = "Lukas Leufen, Felix Kleinert" +__date__ = '2021-04-13' + +from typing import List, Dict + +import numpy as np +import pandas as pd +import xarray as xr +from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates + +from mlair.data_handler import DataCollection +from mlair.helpers import TimeTrackingWrapper +from mlair.plotting.abstract_plot_class import AbstractPlotClass + + +@TimeTrackingWrapper +class PlotStationMap(AbstractPlotClass): + """ + Plot geographical overview of all used stations as squares. + + Different data sets can be colorised by its key in the input dictionary generators. The key represents the color to + plot on the map. Currently, there is only a white background, but this can be adjusted by loading locally stored + topography data (not implemented yet). The plot is saved under plot_path with the name station_map.pdf + + .. image:: ../../../../../_source/_plots/station_map.png + :width: 400 + """ + + def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"): + """ + Set attributes and create plot. + + :param generators: dictionary with the plot color of each data set as key and the generator containing all stations + as value. + :param plot_folder: path to save the plot (default: current directory) + """ + super().__init__(plot_folder, plot_name) + self._ax = None + self._gl = None + self._plot(generators) + self._save(bbox_inches="tight") + + def _draw_background(self): + """Draw coastline, lakes, ocean, rivers and country borders as background on the map.""" + + import cartopy.feature as cfeature + + self._ax.add_feature(cfeature.LAND.with_scale("50m")) + self._ax.natural_earth_shp(resolution='50m') + self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black') + self._ax.add_feature(cfeature.LAKES.with_scale("50m")) + self._ax.add_feature(cfeature.OCEAN.with_scale("50m")) + self._ax.add_feature(cfeature.RIVERS.with_scale("50m")) + self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black') + + def _plot_stations(self, generators): + """ + Loop over all keys in generators dict and its containing stations and plot the stations's position. + + Position is highlighted by a square on the map regarding the given color. + + :param generators: dictionary with the plot color of each data set as key and the generator containing all + stations as value. + """ + + import cartopy.crs as ccrs + if generators is not None: + legend_elements = [] + default_colors = self.get_dataset_colors() + for element in generators: + data_collection, plot_opts = self._get_collection_and_opts(element) + name = data_collection.name or "unknown" + marker = plot_opts.get("marker", "s") + ms = plot_opts.get("ms", 6) + mec = plot_opts.get("mec", "k") + mfc = plot_opts.get("mfc", default_colors.get(name, "b")) + legend_elements.append( + mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None', + label=f"{name} ({len(data_collection)})")) + for station in data_collection: + coords = station.get_coordinates() + IDx, IDy = coords["lon"], coords["lat"] + self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree()) + if len(legend_elements) > 0: + self._ax.legend(handles=legend_elements, loc='best') + + @staticmethod + def _adjust_marker(marker): + _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"} + if isinstance(marker, int) and marker in _adjust.keys(): + return _adjust[marker] + else: + return marker + + @staticmethod + def _get_collection_and_opts(element): + if isinstance(element, tuple): + if len(element) == 1: + return element[0], {} + else: + return element + else: + return element, {} + + def _plot(self, generators: List): + """ + Create the station map plot. + + Set figure and call all required sub-methods. + + :param generators: dictionary with the plot color of each data set as key and the generator containing all + stations as value. + """ + + import cartopy.crs as ccrs + from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER + fig = plt.figure(figsize=(10, 5)) + self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) + self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True) + self._gl.xformatter = LONGITUDE_FORMATTER + self._gl.yformatter = LATITUDE_FORMATTER + self._draw_background() + self._plot_stations(generators) + self._adjust_extent() + plt.tight_layout() + + def _adjust_extent(self): + import cartopy.crs as ccrs + + def diff(arr): + return arr[1] - arr[0], arr[3] - arr[2] + + def find_ratio(delta, reference=5): + return min(max(abs(reference / delta[0]), abs(reference / delta[1])), 5) + + extent = self._ax.get_extent(crs=ccrs.PlateCarree()) + ratio = find_ratio(diff(extent)) + new_extent = extent + np.array([-1, 1, -1, 1]) * ratio + self._ax.set_extent(new_extent, crs=ccrs.PlateCarree()) + + +@TimeTrackingWrapper +class PlotAvailability(AbstractPlotClass): + """ + Create data availablility plot similar to Gantt plot. + + Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal + resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a + colored bar or a blank space. + + You can set different colors to highlight subsets for example by providing different generators for the same index + using different keys in the input dictionary. + + Note: each bar is surrounded by a small white box to highlight gabs in between. This can result in too long gabs + in display, if a gab is only very short. Also this appears on a (fluent) transition from one to another subset. + + Calling this class will create three versions fo the availability plot. + + 1) Data availability for each element + 1) Data availability as summary over all elements (is there at least a single elemnt for each time step) + 1) Combination of single and overall availability + + .. image:: ../../../../../_source/_plots/data_availability.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_summary.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_combined.png + :width: 400 + + """ + + def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily", + summary_name="data availability", time_dimension="datetime", window_dimension="window"): + """Initialise.""" + # create standard Gantt plot for all stations (currently in single pdf file with single page) + super().__init__(plot_folder, "data_availability") + self.time_dim = time_dimension + self.window_dim = window_dimension + self.sampling = self._get_sampling(sampling) + self.linewidth = None + if self.sampling == 'h': + self.linewidth = 0.001 + plot_dict = self._prepare_data(generators) + lgd = self._plot(plot_dict) + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") + # create summary Gantt plot (is data in at least one station available) + self.plot_name += "_summary" + plot_dict_summary = self._summarise_data(generators, summary_name) + lgd = self._plot(plot_dict_summary) + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") + # combination of station and summary plot, last element is summary broken bar + self.plot_name = "data_availability_combined" + plot_dict_summary.update(plot_dict) + lgd = self._plot(plot_dict_summary) + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") + + def _prepare_data(self, generators: Dict[str, DataCollection]): + plt_dict = {} + for subset, data_collection in generators.items(): + for station in data_collection: + labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean() + labels_bool = labels.sel(**{self.window_dim: 1}).notnull() + group = (labels_bool != labels_bool.shift({self.time_dim: 1})).cumsum() + plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values}, + index=labels.coords[self.time_dim].values) + t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) + t2 = [i[1:] for i in t if i[0]] + + if plt_dict.get(str(station)) is None: + plt_dict[str(station)] = {subset: t2} + else: + plt_dict[str(station)].update({subset: t2}) + return plt_dict + + def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str): + plt_dict = {} + for subset, data_collection in generators.items(): + all_data = None + for station in data_collection: + labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean() + labels_bool = labels.sel(**{self.window_dim: 1}).notnull() + if all_data is None: + all_data = labels_bool + else: + tmp = all_data.combine_first(labels_bool) # expand dims to merged datetime coords + all_data = np.logical_or(tmp, labels_bool).combine_first( + all_data) # apply logical on merge and fill missing with all_data + + group = (all_data != all_data.shift({self.time_dim: 1})).cumsum() + plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, + index=all_data.coords[self.time_dim].values) + t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) + t2 = [i[1:] for i in t if i[0]] + if plt_dict.get(summary_name) is None: + plt_dict[summary_name] = {subset: t2} + else: + plt_dict[summary_name].update({subset: t2}) + return plt_dict + + def _plot(self, plt_dict): + colors = self.get_dataset_colors() + _used_colors = [] + pos = 0 + height = 0.8 # should be <= 1 + yticklabels = [] + number_of_stations = len(plt_dict.keys()) + fig, ax = plt.subplots(figsize=(10, number_of_stations / 3)) + for station, d in sorted(plt_dict.items(), reverse=True): + pos += 1 + for subset, color in colors.items(): + plt_data = d.get(subset) + if plt_data is None: + continue + elif color not in _used_colors: # this is required for a proper legend creation + _used_colors.append(color) + ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth) + yticklabels.append(station) + + ax.set_ylim([height, number_of_stations + 1]) + ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2) + ax.set_yticklabels(yticklabels) + handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors] + lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles)) + return lgd + + +@TimeTrackingWrapper +class PlotAvailabilityHistogram(AbstractPlotClass): + """ + Create data availability plots as histogram. + + Each entry of each generator is checked for `notnull()` values along all the datetime axis (boolean). + Calling this class creates two different types of histograms where each generator + + 1) data_availability_histogram: datetime (xaxis) vs. number of stations with availabile data (yaxis) + 2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number + of samples (yaxis) + + .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png + :width: 400 + + """ + + def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", + subset_dim: str = 'DataSet', history_dim: str = 'window', + station_dim: str = 'Stations', ): + + super().__init__(plot_folder, "data_availability_histogram") + + self.subset_dim = subset_dim + self.history_dim = history_dim + self.station_dim = station_dim + + self.freq = None + self.temporal_dim = None + self.target_dim = None + self._prepare_data(generators) + + for plt_type in self.allowed_plot_types: + plot_name_tmp = self.plot_name + self.plot_name += '_' + plt_type + self._plot(plt_type=plt_type) + self._save() + self.plot_name = plot_name_tmp + + def _set_dims_from_datahandler(self, data_handler): + self.temporal_dim = data_handler.id_class.time_dim + self.target_dim = data_handler.id_class.target_dim + self.freq = self._get_sampling(data_handler.id_class.sampling) + + @property + def allowed_plot_types(self): + plot_types = ['hist', 'hist_cum'] + return plot_types + + def _prepare_data(self, generators: Dict[str, DataCollection]): + """ + Prepares data to be used by plot methods. + + Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim + """ + avail_data_time_sum = {} + avail_data_station_sum = {} + dataset_time_interval = {} + for subset, generator in generators.items(): + avail_list = [] + for station in generator: + self._set_dims_from_datahandler(data_handler=station) + station_data_x = station.get_X(as_numpy=False)[0] + station_data_x = station_data_x.loc[{self.history_dim: 0, # select recent window frame + self.target_dim: station_data_x[self.target_dim].values[0]}] + station_data_x = self._reduce_dims(station_data_x) + avail_list.append(station_data_x.notnull()) + avail_data = xr.concat(avail_list, dim=self.station_dim).notnull() + avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim) + avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim) + dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray( + avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict' + ) + avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(), + name=self.subset_dim) + ) + full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq) + self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(), + name=self.subset_dim)) + self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index}) + self.dataset_time_interval = dataset_time_interval + + def _reduce_dims(self, dataset): + if len(dataset.dims) > 2: + required = {self.temporal_dim, self.station_dim} + unimportant = set(dataset.dims).difference(required) + sel_dict = {un: dataset[un].values[0] for un in unimportant} + dataset = dataset.loc[sel_dict] + return dataset + + @staticmethod + def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'): + if isinstance(xarray, xr.DataArray): + first = xarray.coords[dim_name].values[0] + last = xarray.coords[dim_name].values[-1] + if return_type == 'as_tuple': + return first, last + elif return_type == 'as_dict': + return {'first': first, 'last': last} + else: + raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'") + else: + raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}") + + @staticmethod + def _make_full_time_index(irregular_time_index, freq): + full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq) + return full_time_index + + def _plot(self, plt_type='hist', *args): + if plt_type == 'hist': + self._plot_hist() + elif plt_type == 'hist_cum': + self._plot_hist_cum() + else: + raise ValueError(f"plt_type mus be 'hist' or 'hist_cum', but is {type}") + + def _plot_hist(self, *args): + colors = self.get_dataset_colors() + fig, axes = plt.subplots(figsize=(10, 3)) + for i, subset in enumerate(self.dataset_time_interval.keys()): + plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset, + self.temporal_dim: slice( + self.dataset_time_interval[subset]['first'], + self.dataset_time_interval[subset]['last'] + ) + } + ) + + plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset) + plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset]) + + lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval), + facecolor='white', framealpha=1, edgecolor='black') + for lgd_line in lgd.get_lines(): + lgd_line.set_linewidth(4.0) + plt.gca().xaxis.set_major_locator(mdates.YearLocator()) + plt.title('') + plt.ylabel('Number of samples') + plt.tight_layout() + + def _plot_hist_cum(self, *args): + colors = self.get_dataset_colors() + fig, axes = plt.subplots(figsize=(10, 3)) + n_bins = int(self.avail_data_cum_sum.max().values) + bins = np.arange(0, n_bins + 1) + descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby( + self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False + ).coords[self.subset_dim].values + + for subset in descending_subsets: + self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes, + bins=bins, + label=subset, + cumulative=-1, + color=colors[subset], + # alpha=.5 + ) + + lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval), + facecolor='white', framealpha=1, edgecolor='black') + plt.title('') + plt.ylabel('Number of stations') + plt.xlabel('Number of samples') + plt.xlim((bins[0], bins[-1])) + plt.tight_layout() diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 73aebb00..ff74da37 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -19,9 +19,9 @@ from mlair.helpers.datastore import NameNotFoundInDataStore from mlair.helpers import TimeTracking, statistics, extract_value, remove_items, to_list, tables from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules import AbstractModelClass -from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ - PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotAvailabilityHistogram, \ - PlotConditionalQuantiles, PlotSeparationOfScales +from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ + PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, PlotSeparationOfScales +from mlair.plotting.preprocessing_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram from mlair.run_modules.run_environment import RunEnvironment -- GitLab