diff --git a/HPC_setup/requirements_HDFML_additionals.txt b/HPC_setup/requirements_HDFML_additionals.txt index b2a29fbfb353f24d8c99d8429693022ea1fd406f..fd22a309913efa6478a4a00f94bac70433e21774 100644 --- a/HPC_setup/requirements_HDFML_additionals.txt +++ b/HPC_setup/requirements_HDFML_additionals.txt @@ -1,6 +1,7 @@ absl-py==0.11.0 appdirs==1.4.4 astor==0.8.1 +astropy==4.1 attrs==20.3.0 bottleneck==1.3.2 cached-property==1.5.2 diff --git a/HPC_setup/requirements_JUWELS_additionals.txt b/HPC_setup/requirements_JUWELS_additionals.txt index b2a29fbfb353f24d8c99d8429693022ea1fd406f..fd22a309913efa6478a4a00f94bac70433e21774 100644 --- a/HPC_setup/requirements_JUWELS_additionals.txt +++ b/HPC_setup/requirements_JUWELS_additionals.txt @@ -1,6 +1,7 @@ absl-py==0.11.0 appdirs==1.4.4 astor==0.8.1 +astropy==4.1 attrs==20.3.0 bottleneck==1.3.2 cached-property==1.5.2 diff --git a/docs/_source/_plots/periodogram.png b/docs/_source/_plots/periodogram.png new file mode 100644 index 0000000000000000000000000000000000000000..a756cffab18869f615ff504303f2743618f14633 Binary files /dev/null and b/docs/_source/_plots/periodogram.png differ diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py index 75e9e64506231f32406934b67e65454d87a43f61..e25162572582a032361d287fd73c3386cbaf438e 100644 --- a/mlair/data_handler/data_handler_mixed_sampling.py +++ b/mlair/data_handler/data_handler_mixed_sampling.py @@ -56,7 +56,7 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation): kwargs.update({parameter_name: parameter}) def make_input_target(self): - self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data + self._data = tuple(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data self.set_inputs_and_targets() def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]: @@ -110,7 +110,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi A KZ filter is applied on the input data that has hourly resolution. Lables Y are provided as aggregated values with daily resolution. 
""" - self._data = list(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data + self._data = tuple(map(self.load_and_interpolate, [0, 1])) # load input (0) and target (1) data self.set_inputs_and_targets() self.apply_kz_filter() @@ -158,7 +158,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi def _extract_lazy(self, lazy_data): _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data start_inp, end_inp = self.update_start_end(0) - self._data = list(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1])) + self._data = tuple(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1])) self.input_data = self._slice_prep(_input_data, start_inp, end_inp) self.target_data = self._slice_prep(_target_data, self.start, self.end) diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py index 87fc83b0c5d97631b9b0e01aa490be20c107ed1f..3a57d9febc714c81a68c21facab55957eabf32d9 100644 --- a/mlair/data_handler/default_data_handler.py +++ b/mlair/data_handler/default_data_handler.py @@ -299,6 +299,7 @@ class DefaultDataHandler(AbstractDataHandler): for p in output: dh, s = p.get() _inner() + pool.close() else: # serial solution logging.info("use serial transformation approach") for station in set_stations: diff --git a/mlair/plotting/abstract_plot_class.py b/mlair/plotting/abstract_plot_class.py new file mode 100644 index 0000000000000000000000000000000000000000..dab45156ac1bbe033ba073e01245ffc8b65ca6b3 --- /dev/null +++ b/mlair/plotting/abstract_plot_class.py @@ -0,0 +1,101 @@ +"""Abstract plot class that should be used for preprocessing and postprocessing plots.""" +__author__ = "Lukas Leufen" +__date__ = '2021-04-13' + +import logging +import os + +from matplotlib import pyplot as plt + + +class AbstractPlotClass: + """ + Abstract class for all plotting routines to unify plot workflow. + + Each inheritance requires a _plot method. Create a plot class like: + + .. code-block:: python + + class MyCustomPlot(AbstractPlotClass): + + def __init__(self, plot_folder, *args, **kwargs): + super().__init__(plot_folder, "custom_plot_name") + self._data = self._prepare_data(*args, **kwargs) + self._plot(*args, **kwargs) + self._save() + + def _prepare_data(*args, **kwargs): + <your custom data preparation> + return data + + def _plot(*args, **kwargs): + <your custom plotting without saving> + + The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are + using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per default (can be + set in super class initialisation). + + Methods like the shown _prepare_data() are optional. The only method required to implement is _plot. + + If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot + class. It will log the spent time if you call your plotting without saving the returned object. + + .. code-block:: python + + @TimeTrackingWrapper + class MyCustomPlot(AbstractPlotClass): + pass + + Let's assume it takes a while to create this very special plot. 
+
+    >>> MyCustomPlot()
+    INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss)
+
+    """
+
+    def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None):
+        """Set up plot folder and name, and plot resolution (default 500dpi)."""
+        plot_folder = os.path.abspath(plot_folder)
+        if not os.path.exists(plot_folder):
+            os.makedirs(plot_folder)
+        self.plot_folder = plot_folder
+        self.plot_name = plot_name
+        self.resolution = resolution
+        if rc_params is None:
+            rc_params = {'axes.labelsize': 'large',
+                         'xtick.labelsize': 'large',
+                         'ytick.labelsize': 'large',
+                         'legend.fontsize': 'large',
+                         'axes.titlesize': 'large',
+                         }
+        self.rc_params = rc_params
+        self._update_rc_params()
+
+    def _plot(self, *args):
+        """Abstract plot method that must be implemented by the subclass."""
+        raise NotImplementedError
+
+    def _save(self, **kwargs):
+        """Store plot locally. Name of and path to plot need to be set on initialisation."""
+        plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
+        logging.debug(f"... save plot to {plot_name}")
+        plt.savefig(plot_name, dpi=self.resolution, **kwargs)
+        plt.close('all')
+
+    def _update_rc_params(self):
+        plt.rcParams.update(self.rc_params)
+
+    @staticmethod
+    def _get_sampling(sampling):
+        if sampling == "daily":
+            return "D"
+        elif sampling == "hourly":
+            return "h"
+
+    @staticmethod
+    def get_dataset_colors():
+        """Standard colors used for train, val, and test sets during postprocessing."""
+        colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"}  # hex code
+        return colors
diff --git a/mlair/plotting/data_insight_plotting.py b/mlair/plotting/data_insight_plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..1176621a71f09e6efff4ac21a69e4f466e6dfbd4
--- /dev/null
+++ b/mlair/plotting/data_insight_plotting.py
@@ -0,0 +1,721 @@
+"""Collection of plots to get more insight into data."""
+__author__ = "Lukas Leufen, Felix Kleinert"
+__date__ = '2021-04-13'
+
+from typing import List, Dict
+import os
+import logging
+import multiprocessing
+import psutil
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+import matplotlib
+from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates
+from astropy.timeseries import LombScargle
+
+from mlair.data_handler import DataCollection
+from mlair.helpers import TimeTrackingWrapper, to_list
+from mlair.plotting.abstract_plot_class import AbstractPlotClass
+
+
+@TimeTrackingWrapper
+class PlotStationMap(AbstractPlotClass):  # pragma: no cover
+    """
+    Plot geographical overview of all used stations as squares.
+
+    Different data sets can be colorised by their entry in the input list generators; marker style and color can be
+    set per data set via plot options. Currently, there is only a white background, but this can be adjusted by
+    loading locally stored topography data (not implemented yet). The plot is saved under plot_path with the name
+    station_map.pdf.
+
+    .. image:: ../../../../../_source/_plots/station_map.png
+        :width: 400
+    """
+
+    def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
+        """
+        Set attributes and create plot.
+
+        :param generators: list of DataCollections to plot; each element can also be given as a tuple of
+            (DataCollection, plot options dict) to style its markers individually.
+        :param plot_folder: path to save the plot (default: current directory)
+        """
+        super().__init__(plot_folder, plot_name)
+        self._ax = None
+        self._gl = None
+        self._plot(generators)
+        self._save(bbox_inches="tight")
+
+    def _draw_background(self):
+        """Draw coastline, lakes, ocean, rivers and country borders as background on the map."""
+
+        import cartopy.feature as cfeature
+
+        self._ax.add_feature(cfeature.LAND.with_scale("50m"))
+        self._ax.natural_earth_shp(resolution='50m')
+        self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
+        self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
+        self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
+        self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
+        self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')
+
+    def _plot_stations(self, generators):
+        """
+        Loop over all data collections in generators and plot each station's position.
+
+        Each position is highlighted by a marker (a square by default) in the collection's color.
+
+        :param generators: list of DataCollections to plot; each element can also be given as a tuple of
+            (DataCollection, plot options dict).
+        """
+
+        import cartopy.crs as ccrs
+        if generators is not None:
+            legend_elements = []
+            default_colors = self.get_dataset_colors()
+            for element in generators:
+                data_collection, plot_opts = self._get_collection_and_opts(element)
+                name = data_collection.name or "unknown"
+                marker = plot_opts.get("marker", "s")
+                ms = plot_opts.get("ms", 6)
+                mec = plot_opts.get("mec", "k")
+                mfc = plot_opts.get("mfc", default_colors.get(name, "b"))
+                legend_elements.append(
+                    mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None',
+                                  label=f"{name} ({len(data_collection)})"))
+                for station in data_collection:
+                    coords = station.get_coordinates()
+                    IDx, IDy = coords["lon"], coords["lat"]
+                    self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree())
+            if len(legend_elements) > 0:
+                self._ax.legend(handles=legend_elements, loc='best')
+
+    @staticmethod
+    def _adjust_marker(marker):
+        _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
+        if isinstance(marker, int) and marker in _adjust.keys():
+            return _adjust[marker]
+        else:
+            return marker
+
+    @staticmethod
+    def _get_collection_and_opts(element):
+        if isinstance(element, tuple):
+            if len(element) == 1:
+                return element[0], {}
+            else:
+                return element
+        else:
+            return element, {}
+
+    def _plot(self, generators: List):
+        """
+        Create the station map plot.
+
+        Set figure and call all required sub-methods.
+
+        :param generators: list of DataCollections to plot; each element can also be given as a tuple of
+            (DataCollection, plot options dict).
+ """ + + import cartopy.crs as ccrs + from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER + fig = plt.figure(figsize=(10, 5)) + self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) + self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True) + self._gl.xformatter = LONGITUDE_FORMATTER + self._gl.yformatter = LATITUDE_FORMATTER + self._draw_background() + self._plot_stations(generators) + self._adjust_extent() + plt.tight_layout() + + def _adjust_extent(self): + import cartopy.crs as ccrs + + def diff(arr): + return arr[1] - arr[0], arr[3] - arr[2] + + def find_ratio(delta, reference=5): + return min(max(abs(reference / delta[0]), abs(reference / delta[1])), 5) + + extent = self._ax.get_extent(crs=ccrs.PlateCarree()) + ratio = find_ratio(diff(extent)) + new_extent = extent + np.array([-1, 1, -1, 1]) * ratio + self._ax.set_extent(new_extent, crs=ccrs.PlateCarree()) + + +@TimeTrackingWrapper +class PlotAvailability(AbstractPlotClass): # pragma: no cover + """ + Create data availablility plot similar to Gantt plot. + + Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal + resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a + colored bar or a blank space. + + You can set different colors to highlight subsets for example by providing different generators for the same index + using different keys in the input dictionary. + + Note: each bar is surrounded by a small white box to highlight gabs in between. This can result in too long gabs + in display, if a gab is only very short. Also this appears on a (fluent) transition from one to another subset. + + Calling this class will create three versions fo the availability plot. + + 1) Data availability for each element + 1) Data availability as summary over all elements (is there at least a single elemnt for each time step) + 1) Combination of single and overall availability + + .. image:: ../../../../../_source/_plots/data_availability.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_summary.png + :width: 400 + + .. 
image:: ../../../../../_source/_plots/data_availability_combined.png + :width: 400 + + """ + + def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily", + summary_name="data availability", time_dimension="datetime", window_dimension="window"): + """Initialise.""" + # create standard Gantt plot for all stations (currently in single pdf file with single page) + super().__init__(plot_folder, "data_availability") + self.time_dim = time_dimension + self.window_dim = window_dimension + self.sampling = self._get_sampling(sampling) + self.linewidth = None + if self.sampling == 'h': + self.linewidth = 0.001 + plot_dict = self._prepare_data(generators) + lgd = self._plot(plot_dict) + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") + # create summary Gantt plot (is data in at least one station available) + self.plot_name += "_summary" + plot_dict_summary = self._summarise_data(generators, summary_name) + lgd = self._plot(plot_dict_summary) + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") + # combination of station and summary plot, last element is summary broken bar + self.plot_name = "data_availability_combined" + plot_dict_summary.update(plot_dict) + lgd = self._plot(plot_dict_summary) + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") + + def _prepare_data(self, generators: Dict[str, DataCollection]): + plt_dict = {} + for subset, data_collection in generators.items(): + for station in data_collection: + labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean() + labels_bool = labels.sel(**{self.window_dim: 1}).notnull() + group = (labels_bool != labels_bool.shift({self.time_dim: 1})).cumsum() + plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values}, + index=labels.coords[self.time_dim].values) + t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) + t2 = [i[1:] for i in t if i[0]] + + if plt_dict.get(str(station)) is None: + plt_dict[str(station)] = {subset: t2} + else: + plt_dict[str(station)].update({subset: t2}) + return plt_dict + + def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str): + plt_dict = {} + for subset, data_collection in generators.items(): + all_data = None + for station in data_collection: + labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean() + labels_bool = labels.sel(**{self.window_dim: 1}).notnull() + if all_data is None: + all_data = labels_bool + else: + tmp = all_data.combine_first(labels_bool) # expand dims to merged datetime coords + all_data = np.logical_or(tmp, labels_bool).combine_first( + all_data) # apply logical on merge and fill missing with all_data + + group = (all_data != all_data.shift({self.time_dim: 1})).cumsum() + plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, + index=all_data.coords[self.time_dim].values) + t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) + t2 = [i[1:] for i in t if i[0]] + if plt_dict.get(summary_name) is None: + plt_dict[summary_name] = {subset: t2} + else: + plt_dict[summary_name].update({subset: t2}) + return plt_dict + + def _plot(self, plt_dict): + colors = self.get_dataset_colors() + _used_colors = [] + pos = 0 + height = 0.8 # should be <= 1 + yticklabels = [] + number_of_stations = len(plt_dict.keys()) + fig, ax = plt.subplots(figsize=(10, number_of_stations / 3)) + for station, d in 
sorted(plt_dict.items(), reverse=True): + pos += 1 + for subset, color in colors.items(): + plt_data = d.get(subset) + if plt_data is None: + continue + elif color not in _used_colors: # this is required for a proper legend creation + _used_colors.append(color) + ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth) + yticklabels.append(station) + + ax.set_ylim([height, number_of_stations + 1]) + ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2) + ax.set_yticklabels(yticklabels) + handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors] + lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles)) + return lgd + + +@TimeTrackingWrapper +class PlotAvailabilityHistogram(AbstractPlotClass): # pragma: no cover + """ + Create data availability plots as histogram. + + Each entry of each generator is checked for `notnull()` values along all the datetime axis (boolean). + Calling this class creates two different types of histograms where each generator + + 1) data_availability_histogram: datetime (xaxis) vs. number of stations with availabile data (yaxis) + 2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number + of samples (yaxis) + + .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png + :width: 400 + + """ + + def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", + subset_dim: str = 'DataSet', history_dim: str = 'window', + station_dim: str = 'Stations', ): + + super().__init__(plot_folder, "data_availability_histogram") + + self.subset_dim = subset_dim + self.history_dim = history_dim + self.station_dim = station_dim + + self.freq = None + self.temporal_dim = None + self.target_dim = None + self._prepare_data(generators) + + for plt_type in self.allowed_plot_types: + plot_name_tmp = self.plot_name + self.plot_name += '_' + plt_type + self._plot(plt_type=plt_type) + self._save() + self.plot_name = plot_name_tmp + + def _set_dims_from_datahandler(self, data_handler): + self.temporal_dim = data_handler.id_class.time_dim + self.target_dim = data_handler.id_class.target_dim + self.freq = self._get_sampling(data_handler.id_class.sampling) + + @property + def allowed_plot_types(self): + plot_types = ['hist', 'hist_cum'] + return plot_types + + def _prepare_data(self, generators: Dict[str, DataCollection]): + """ + Prepares data to be used by plot methods. 
+
+        Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim.
+        """
+        avail_data_time_sum = {}
+        avail_data_station_sum = {}
+        dataset_time_interval = {}
+        for subset, generator in generators.items():
+            avail_list = []
+            for station in generator:
+                self._set_dims_from_datahandler(data_handler=station)
+                station_data_x = station.get_X(as_numpy=False)[0]
+                station_data_x = station_data_x.loc[{self.history_dim: 0,  # select recent window frame
+                                                     self.target_dim: station_data_x[self.target_dim].values[0]}]
+                station_data_x = self._reduce_dims(station_data_x)
+                avail_list.append(station_data_x.notnull())
+            avail_data = xr.concat(avail_list, dim=self.station_dim).notnull()
+            avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim)
+            avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim)
+            dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray(
+                avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict'
+            )
+        avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(),
+                                                                             name=self.subset_dim)
+                                      )
+        full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq)
+        self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(),
+                                                                                      name=self.subset_dim))
+        self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index})
+        self.dataset_time_interval = dataset_time_interval
+
+    def _reduce_dims(self, dataset):
+        if len(dataset.dims) > 2:
+            required = {self.temporal_dim, self.station_dim}
+            unimportant = set(dataset.dims).difference(required)
+            sel_dict = {un: dataset[un].values[0] for un in unimportant}
+            dataset = dataset.loc[sel_dict]
+        return dataset
+
+    @staticmethod
+    def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'):
+        if isinstance(xarray, xr.DataArray):
+            first = xarray.coords[dim_name].values[0]
+            last = xarray.coords[dim_name].values[-1]
+            if return_type == 'as_tuple':
+                return first, last
+            elif return_type == 'as_dict':
+                return {'first': first, 'last': last}
+            else:
+                raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'")
+        else:
+            raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}")
+
+    @staticmethod
+    def _make_full_time_index(irregular_time_index, freq):
+        full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq)
+        return full_time_index
+
+    def _plot(self, plt_type='hist', *args):
+        if plt_type == 'hist':
+            self._plot_hist()
+        elif plt_type == 'hist_cum':
+            self._plot_hist_cum()
+        else:
+            raise ValueError(f"plt_type must be 'hist' or 'hist_cum', but is {plt_type}")
+
+    def _plot_hist(self, *args):
+        colors = self.get_dataset_colors()
+        fig, axes = plt.subplots(figsize=(10, 3))
+        for i, subset in enumerate(self.dataset_time_interval.keys()):
+            plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset,
+                                                       self.temporal_dim: slice(
+                                                           self.dataset_time_interval[subset]['first'],
+                                                           self.dataset_time_interval[subset]['last']
+                                                       )
+                                                       }
+                                                      )
+
+            plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset)
+            plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset])
+
+        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
+                         facecolor='white', framealpha=1, edgecolor='black')
+        for lgd_line in lgd.get_lines():
+            lgd_line.set_linewidth(4.0)
+        plt.gca().xaxis.set_major_locator(mdates.YearLocator())
+        plt.title('')
+        plt.ylabel('Number of samples')
+        plt.tight_layout()
+
+    def _plot_hist_cum(self, *args):
+        colors = self.get_dataset_colors()
+        fig, axes = plt.subplots(figsize=(10, 3))
+        n_bins = int(self.avail_data_cum_sum.max().values)
+        bins = np.arange(0, n_bins + 1)
+        descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby(
+            self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False
+        ).coords[self.subset_dim].values
+
+        for subset in descending_subsets:
+            self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes,
+                                                                             bins=bins,
+                                                                             label=subset,
+                                                                             cumulative=-1,
+                                                                             color=colors[subset],
+                                                                             # alpha=.5
+                                                                             )
+
+        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
+                         facecolor='white', framealpha=1, edgecolor='black')
+        plt.title('')
+        plt.ylabel('Number of stations')
+        plt.xlabel('Number of samples')
+        plt.xlim((bins[0], bins[-1]))
+        plt.tight_layout()
+
+
+class PlotPeriodogram(AbstractPlotClass):  # pragma: no cover
+    """
+    Create a Lomb-Scargle periodogram of the raw input and target data. The Lomb-Scargle method can deal with missing
+    values.
+
+    This plot routine creates the following plots:
+
+    * "raw": data is not aggregated, 1 graph per variable
+    * "" (no suffix): single data lines are aggregated, 1 graph per variable
+    * "total": data is aggregated over all variables, single graph
+
+    If the data consist of different sampling rates, a separate plot is created for each sampling.
+
+    .. image:: ../../../../../_source/_plots/periodogram.png
+        :width: 400
+
+    .. note::
+        This plot is not included in the default plot list. To use this plot, add "PlotPeriodogram" to the `plot_list`.
+
+    .. warning::
+        This plot is highly sensitive to the data handler structure. Therefore, it is highly likely that this method is
+        not compatible with custom data handlers. Proven data handlers are `DefaultDataHandler`,
+        `DataHandlerMixedSampling`, and `DataHandlerMixedSamplingWithFilter`. To work properly, the data handler must
+        have the attribute `.id_class._data`.
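+
+    A minimal usage sketch (``data_collection`` is a placeholder for any prepared ``DataCollection``, e.g. the train
+    set used in postprocessing; all keyword values shown are the constructor defaults):
+
+    .. code-block:: python
+
+        # not part of the default plot list, so instantiate explicitly with matching dimension names
+        PlotPeriodogram(data_collection, plot_folder=".", variables_dim="variables", time_dim="datetime",
+                        sampling="daily", use_multiprocessing=False)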
+ + """ + + def __init__(self, generator: Dict[str, DataCollection], plot_folder: str = ".", plot_name="periodogram", + variables_dim="variables", time_dim="datetime", sampling="daily", use_multiprocessing=False): + super().__init__(plot_folder, plot_name) + self.variables_dim = variables_dim + self.time_dim = time_dim + + for pos, s in enumerate(sampling if isinstance(sampling, tuple) else (sampling, sampling)): + self._sampling = s + self._add_text = {0: "input", 1: "target"}[pos] + multiple, label_names = self._has_filter_dimension(generator[0], pos) + self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing) + self._plot(raw=True) + self._plot(raw=False) + self._plot_total(raw=True) + self._plot_total(raw=False) + if multiple > 1: + self._plot_difference(label_names) + + @staticmethod + def _has_filter_dimension(g, pos): + # check if coords raw data differs from input / target data + check_data = g.id_class + if "filter" not in [check_data.input_data, check_data.target_data][pos].coords.dims: + return 1, [] + else: + if len(set(check_data._data[0].coords.dims).symmetric_difference(check_data.input_data.coords.dims)) > 0: + return g.id_class.input_data.coords["filter"].shape[0], g.id_class.input_data.coords[ + "filter"].values.tolist() + else: + return 1, [] + + @TimeTrackingWrapper + def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False): + """ + Create periodogram data. + """ + self.raw_data = [] + self.plot_data = [] + self.plot_data_raw = [] + self.plot_data_mean = [] + iter = range(multiple if multiple == 1 else multiple + 1) + for m in iter: + plot_data_single = dict() + plot_data_raw_single = dict() + plot_data_mean_single = dict() + raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing) + # raw_data_single = self._prepare_pgram_parallel_var(generator, m, pos, use_multiprocessing) + self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000) + for var in raw_data_single.keys(): + pgram_com = [] + pgram_mean = 0 + all_data = raw_data_single[var] + pgram_mean_raw = np.zeros((len(self.f_index), len(all_data))) + for i, (f, pgram) in enumerate(all_data): + d = np.interp(self.f_index, f, pgram) + pgram_com.append(d) + pgram_mean += d + pgram_mean_raw[:, i] = d + pgram_mean /= len(all_data) + plot_data_single[var] = pgram_com + plot_data_mean_single[var] = (self.f_index, pgram_mean) + plot_data_raw_single[var] = (self.f_index, pgram_mean_raw) + self.plot_data.append(plot_data_single) + self.plot_data_mean.append(plot_data_mean_single) + self.plot_data_raw.append(plot_data_raw_single) + + def _prepare_pgram_parallel_var(self, generator, m, pos, use_multiprocessing): + """Implementation of data preprocessing using parallel variables element processing.""" + raw_data_single = dict() + for g in generator: + if m == 0: + d = g.id_class._data + else: + gd = g.id_class + filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]} + d = (gd.input_data.sel(filter_sel), gd.target_data) + d = d[pos] if isinstance(d, tuple) else d + res = [] + if multiprocessing.cpu_count() > 1 and use_multiprocessing: # parallel solution + pool = multiprocessing.Pool( + min([psutil.cpu_count(logical=False), len(d[self.variables_dim].values), + 16])) # use only physical cpus + output = [ + pool.apply_async(f_proc, + args=(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim))) + for var in d[self.variables_dim].values] + for i, p in enumerate(output): + res.append(p.get()) + 
pool.close() + else: # serial solution + for var in d[self.variables_dim].values: + res.append(f_proc(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim))) + for (var_str, f, pgram) in res: + if var_str not in raw_data_single.keys(): + raw_data_single[var_str] = [(f, pgram)] + else: + raw_data_single[var_str] = raw_data_single[var_str] + [(f, pgram)] + return raw_data_single + + def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing): + """Implementation of data preprocessing using parallel generator element processing.""" + raw_data_single = dict() + res = [] + if multiprocessing.cpu_count() > 1 and use_multiprocessing: # parallel solution + pool = multiprocessing.Pool( + min([psutil.cpu_count(logical=False), len(generator), 16])) # use only physical cpus + output = [ + pool.apply_async(f_proc_2, args=(g, m, pos, self.variables_dim, self.time_dim)) + for g in generator] + for i, p in enumerate(output): + res.append(p.get()) + pool.close() + else: + for g in generator: + res.append(f_proc_2(g, m, pos, self.variables_dim, self.time_dim)) + for res_dict in res: + for k, v in res_dict.items(): + if k not in raw_data_single.keys(): + raw_data_single[k] = v + else: + raw_data_single[k] = raw_data_single[k] + v + return raw_data_single + + @staticmethod + def _add_annotation_line(ax, pos, div, lims, unit): + for p in to_list(pos): # per year + ax.vlines(p / div, *lims, "black") + ax.text(p / div, lims[0], r"%s$%s^{-1}$" % (p, unit), rotation="vertical", rotation_mode="anchor") + + def _format_figure(self, ax, var_name="total"): + """ + Set log scale on both axis, add labels and annotation lines, and set title. + :param ax: current ax object + :param var_name: name of variable that will be included in the title + """ + ax.set_yscale('log') + ax.set_xscale('log') + ax.set_ylabel("power", fontsize='x-large') + ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large') + lims = ax.get_ylim() + self._add_annotation_line(ax, [1, 2, 3], 365.25, lims, "yr") # per year + self._add_annotation_line(ax, 1, 365.25 / 12, lims, "m") # per month + self._add_annotation_line(ax, 1, 7, lims, "w") # per week + self._add_annotation_line(ax, [1, 0.5], 1, lims, "d") # per day + if self._sampling == "hourly": + self._add_annotation_line(ax, 2, 1, lims, "d") # per day + self._add_annotation_line(ax, [1, 0.5], 1 / 24., lims, "h") # per hour + title = f"Periodogram ({var_name})" + ax.set_title(title) + + def _plot(self, raw=True): + plot_path = os.path.join(os.path.abspath(self.plot_folder), + f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}.pdf") + pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) + plot_data = self.plot_data[0] + plot_data_mean = self.plot_data_mean[0] + for var in plot_data.keys(): + fig, ax = plt.subplots() + if raw is True: + for pgram in plot_data[var]: + ax.plot(self.f_index, pgram, "lightblue") + ax.plot(*plot_data_mean[var], "blue") + else: + ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0) + mean = ma.mean().mean(axis=1).values.flatten() + upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten() + ax.plot(self.f_index, mean, "blue") + ax.fill_between(self.f_index, lower, upper, color="lightblue") + self._format_figure(ax, var) + pdf_pages.savefig() + # close all open figures / plots + pdf_pages.close() + plt.close('all') + + def _plot_total(self, raw=True): + plot_path = os.path.join(os.path.abspath(self.plot_folder), + f"{self.plot_name}{'_raw' if 
raw else ''}_{self._sampling}_{self._add_text}_total.pdf") + pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) + plot_data_raw = self.plot_data_raw[0] + fig, ax = plt.subplots() + res = None + for var in plot_data_raw.keys(): + d_var = plot_data_raw[var][1] + res = d_var if res is None else np.concatenate((res, d_var), axis=-1) + if raw is True: + for i in range(res.shape[1]): + ax.plot(self.f_index, res[:, i], "lightblue") + ax.plot(self.f_index, res.mean(axis=1), "blue") + else: + ma = pd.DataFrame(np.vstack(res)).rolling(5, center=True, axis=0) + mean = ma.mean().mean(axis=1).values.flatten() + upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten() + ax.plot(self.f_index, mean, "blue") + ax.fill_between(self.f_index, lower, upper, color="lightblue") + self._format_figure(ax, "total") + pdf_pages.savefig() + # close all open figures / plots + pdf_pages.close() + plt.close('all') + + def _plot_difference(self, label_names): + plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter.pdf" + plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name) + logging.info(f"... plotting {plot_name}") + pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) + colors = ["blue", "red", "green", "orange", "purple", "black", "grey"] + label_names = ["orig"] + label_names + max_iter = len(self.plot_data) + var_keys = self.plot_data[0].keys() + for var in var_keys: + fig, ax = plt.subplots() + for i in reversed(range(max_iter)): + plot_data = self.plot_data[i] + c = colors[i] + ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0) + mean = ma.mean().mean(axis=1).values.flatten() + ax.plot(self.f_index, mean, c, label=label_names[i]) + if i < 1: + upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten() + ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5, label=None) + self._format_figure(ax, var) + ax.legend(loc="upper center", ncol=max_iter) + pdf_pages.savefig() + # close all open figures / plots + pdf_pages.close() + plt.close('all') + + +def f_proc(var, d_var): + var_str = str(var) + t = (d_var.datetime - d_var.datetime[0]).astype("timedelta64[h]").values / np.timedelta64(1, "D") + f, pgram = LombScargle(t, d_var.values.flatten(), nterms=1).autopower() + return var_str, f, pgram + + +def f_proc_2(g, m, pos, variables_dim, time_dim): + raw_data_single = dict() + if m == 0: + d = g.id_class._data + else: + gd = g.id_class + filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]} + d = (gd.input_data.sel(filter_sel), gd.target_data) + d = d[pos] if isinstance(d, tuple) else d + for var in d[variables_dim].values: + d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim) + var_str, f, pgram = f_proc(var, d_var) + raw_data_single[var_str] = [(f, pgram)] + return raw_data_single diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index d769fabce5702ceb6bb29cf726b4ccf82657ebb4..491aa52e0a9fe0010f77cde315d1f9b7ddb76dfb 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -9,10 +9,7 @@ import warnings from typing import Dict, List, Tuple import matplotlib -import matplotlib.patches as mpatches -import matplotlib.lines as mlines import matplotlib.pyplot as plt -import matplotlib.dates as mdates import numpy as np import pandas as pd import seaborn as sns @@ -22,6 +19,7 @@ from matplotlib.backends.backend_pdf import PdfPages from 
mlair import helpers from mlair.data_handler.iterator import DataCollection from mlair.helpers import TimeTrackingWrapper +from mlair.plotting.abstract_plot_class import AbstractPlotClass logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -31,100 +29,6 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING) # import matplotlib.pyplot as plt -class AbstractPlotClass: - """ - Abstract class for all plotting routines to unify plot workflow. - - Each inheritance requires a _plot method. Create a plot class like: - - .. code-block:: python - - class MyCustomPlot(AbstractPlotClass): - - def __init__(self, plot_folder, *args, **kwargs): - super().__init__(plot_folder, "custom_plot_name") - self._data = self._prepare_data(*args, **kwargs) - self._plot(*args, **kwargs) - self._save() - - def _prepare_data(*args, **kwargs): - <your custom data preparation> - return data - - def _plot(*args, **kwargs): - <your custom plotting without saving> - - The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are - using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per default (can be - set in super class initialisation). - - Methods like the shown _prepare_data() are optional. The only method required to implement is _plot. - - If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot - class. It will log the spent time if you call your plotting without saving the returned object. - - .. code-block:: python - - @TimeTrackingWrapper - class MyCustomPlot(AbstractPlotClass): - pass - - Let's assume it takes a while to create this very special plot. - - >>> MyCustomPlot() - INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss) - - """ - - def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None): - """Set up plot folder and name, and plot resolution (default 500dpi).""" - plot_folder = os.path.abspath(plot_folder) - if not os.path.exists(plot_folder): - os.makedirs(plot_folder) - self.plot_folder = plot_folder - self.plot_name = plot_name - self.resolution = resolution - if rc_params is None: - rc_params = {'axes.labelsize': 'large', - 'xtick.labelsize': 'large', - 'ytick.labelsize': 'large', - 'legend.fontsize': 'large', - 'axes.titlesize': 'large', - } - self.rc_params = rc_params - self._update_rc_params() - - def _plot(self, *args): - """Abstract plot class needs to be implemented in inheritance.""" - raise NotImplementedError - - def _save(self, **kwargs): - """Store plot locally. Name of and path to plot need to be set on initialisation.""" - plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf") - logging.debug(f"... 
save plot to {plot_name}") - plt.savefig(plot_name, dpi=self.resolution, **kwargs) - plt.close('all') - - def _update_rc_params(self): - plt.rcParams.update(self.rc_params) - - @staticmethod - def _get_sampling(sampling): - if sampling == "daily": - return "D" - elif sampling == "hourly": - return "h" - - @staticmethod - def get_dataset_colors(): - """ - Standard colors used for train-, val-, and test-sets during postprocessing - """ - colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"} # hex code - return colors - - - @TimeTrackingWrapper class PlotMonthlySummary(AbstractPlotClass): """ @@ -230,132 +134,6 @@ class PlotMonthlySummary(AbstractPlotClass): plt.tight_layout() -@TimeTrackingWrapper -class PlotStationMap(AbstractPlotClass): - """ - Plot geographical overview of all used stations as squares. - - Different data sets can be colorised by its key in the input dictionary generators. The key represents the color to - plot on the map. Currently, there is only a white background, but this can be adjusted by loading locally stored - topography data (not implemented yet). The plot is saved under plot_path with the name station_map.pdf - - .. image:: ../../../../../_source/_plots/station_map.png - :width: 400 - """ - - def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"): - """ - Set attributes and create plot. - - :param generators: dictionary with the plot color of each data set as key and the generator containing all stations - as value. - :param plot_folder: path to save the plot (default: current directory) - """ - super().__init__(plot_folder, plot_name) - self._ax = None - self._gl = None - self._plot(generators) - self._save(bbox_inches="tight") - - def _draw_background(self): - """Draw coastline, lakes, ocean, rivers and country borders as background on the map.""" - - import cartopy.feature as cfeature - - self._ax.add_feature(cfeature.LAND.with_scale("50m")) - self._ax.natural_earth_shp(resolution='50m') - self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black') - self._ax.add_feature(cfeature.LAKES.with_scale("50m")) - self._ax.add_feature(cfeature.OCEAN.with_scale("50m")) - self._ax.add_feature(cfeature.RIVERS.with_scale("50m")) - self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black') - - def _plot_stations(self, generators): - """ - Loop over all keys in generators dict and its containing stations and plot the stations's position. - - Position is highlighted by a square on the map regarding the given color. - - :param generators: dictionary with the plot color of each data set as key and the generator containing all - stations as value. 
- """ - - import cartopy.crs as ccrs - if generators is not None: - legend_elements = [] - default_colors = self.get_dataset_colors() - for element in generators: - data_collection, plot_opts = self._get_collection_and_opts(element) - name = data_collection.name or "unknown" - marker = plot_opts.get("marker", "s") - ms = plot_opts.get("ms", 6) - mec = plot_opts.get("mec", "k") - mfc = plot_opts.get("mfc", default_colors.get(name, "b")) - legend_elements.append( - mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None', - label=f"{name} ({len(data_collection)})")) - for station in data_collection: - coords = station.get_coordinates() - IDx, IDy = coords["lon"], coords["lat"] - self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree()) - if len(legend_elements) > 0: - self._ax.legend(handles=legend_elements, loc='best') - - @staticmethod - def _adjust_marker(marker): - _adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"} - if isinstance(marker, int) and marker in _adjust.keys(): - return _adjust[marker] - else: - return marker - - @staticmethod - def _get_collection_and_opts(element): - if isinstance(element, tuple): - if len(element) == 1: - return element[0], {} - else: - return element - else: - return element, {} - - def _plot(self, generators: List): - """ - Create the station map plot. - - Set figure and call all required sub-methods. - - :param generators: dictionary with the plot color of each data set as key and the generator containing all - stations as value. - """ - - import cartopy.crs as ccrs - from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER - fig = plt.figure(figsize=(10, 5)) - self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) - self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True) - self._gl.xformatter = LONGITUDE_FORMATTER - self._gl.yformatter = LATITUDE_FORMATTER - self._draw_background() - self._plot_stations(generators) - self._adjust_extent() - plt.tight_layout() - - def _adjust_extent(self): - import cartopy.crs as ccrs - - def diff(arr): - return arr[1] - arr[0], arr[3] - arr[2] - - def find_ratio(delta, reference=5): - return min(max(abs(reference / delta[0]), abs(reference / delta[1])), 5) - - extent = self._ax.get_extent(crs=ccrs.PlateCarree()) - ratio = find_ratio(diff(extent)) - new_extent = extent + np.array([-1, 1, -1, 1]) * ratio - self._ax.set_extent(new_extent, crs=ccrs.PlateCarree()) - - @TimeTrackingWrapper class PlotConditionalQuantiles(AbstractPlotClass): """ @@ -1138,133 +916,6 @@ class PlotTimeSeries: return matplotlib.backends.backend_pdf.PdfPages(plot_name) -@TimeTrackingWrapper -class PlotAvailability(AbstractPlotClass): - """ - Create data availablility plot similar to Gantt plot. - - Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal - resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a - colored bar or a blank space. - - You can set different colors to highlight subsets for example by providing different generators for the same index - using different keys in the input dictionary. - - Note: each bar is surrounded by a small white box to highlight gabs in between. This can result in too long gabs - in display, if a gab is only very short. Also this appears on a (fluent) transition from one to another subset. 
- - Calling this class will create three versions fo the availability plot. - - 1) Data availability for each element - 1) Data availability as summary over all elements (is there at least a single elemnt for each time step) - 1) Combination of single and overall availability - - .. image:: ../../../../../_source/_plots/data_availability.png - :width: 400 - - .. image:: ../../../../../_source/_plots/data_availability_summary.png - :width: 400 - - .. image:: ../../../../../_source/_plots/data_availability_combined.png - :width: 400 - - """ - - def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily", - summary_name="data availability", time_dimension="datetime", window_dimension="window"): - """Initialise.""" - # create standard Gantt plot for all stations (currently in single pdf file with single page) - super().__init__(plot_folder, "data_availability") - self.time_dim = time_dimension - self.window_dim = window_dimension - self.sampling = self._get_sampling(sampling) - self.linewidth = None - if self.sampling == 'h': - self.linewidth = 0.001 - plot_dict = self._prepare_data(generators) - lgd = self._plot(plot_dict) - self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") - # create summary Gantt plot (is data in at least one station available) - self.plot_name += "_summary" - plot_dict_summary = self._summarise_data(generators, summary_name) - lgd = self._plot(plot_dict_summary) - self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") - # combination of station and summary plot, last element is summary broken bar - self.plot_name = "data_availability_combined" - plot_dict_summary.update(plot_dict) - lgd = self._plot(plot_dict_summary) - self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") - - def _prepare_data(self, generators: Dict[str, DataCollection]): - plt_dict = {} - for subset, data_collection in generators.items(): - for station in data_collection: - labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean() - labels_bool = labels.sel(**{self.window_dim: 1}).notnull() - group = (labels_bool != labels_bool.shift({self.time_dim: 1})).cumsum() - plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values}, - index=labels.coords[self.time_dim].values) - t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) - t2 = [i[1:] for i in t if i[0]] - - if plt_dict.get(str(station)) is None: - plt_dict[str(station)] = {subset: t2} - else: - plt_dict[str(station)].update({subset: t2}) - return plt_dict - - def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str): - plt_dict = {} - for subset, data_collection in generators.items(): - all_data = None - for station in data_collection: - labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean() - labels_bool = labels.sel(**{self.window_dim: 1}).notnull() - if all_data is None: - all_data = labels_bool - else: - tmp = all_data.combine_first(labels_bool) # expand dims to merged datetime coords - all_data = np.logical_or(tmp, labels_bool).combine_first( - all_data) # apply logical on merge and fill missing with all_data - - group = (all_data != all_data.shift({self.time_dim: 1})).cumsum() - plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, - index=all_data.coords[self.time_dim].values) - t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) - t2 = [i[1:] for i in 
t if i[0]] - if plt_dict.get(summary_name) is None: - plt_dict[summary_name] = {subset: t2} - else: - plt_dict[summary_name].update({subset: t2}) - return plt_dict - - def _plot(self, plt_dict): - colors = self.get_dataset_colors() - _used_colors = [] - pos = 0 - height = 0.8 # should be <= 1 - yticklabels = [] - number_of_stations = len(plt_dict.keys()) - fig, ax = plt.subplots(figsize=(10, number_of_stations / 3)) - for station, d in sorted(plt_dict.items(), reverse=True): - pos += 1 - for subset, color in colors.items(): - plt_data = d.get(subset) - if plt_data is None: - continue - elif color not in _used_colors: # this is required for a proper legend creation - _used_colors.append(color) - ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth) - yticklabels.append(station) - - ax.set_ylim([height, number_of_stations + 1]) - ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2) - ax.set_yticklabels(yticklabels) - handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors] - lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles)) - return lgd - - @TimeTrackingWrapper class PlotSeparationOfScales(AbstractPlotClass): @@ -1292,178 +943,6 @@ class PlotSeparationOfScales(AbstractPlotClass): self._save() -@TimeTrackingWrapper -class PlotAvailabilityHistogram(AbstractPlotClass): - """ - Create data availability plots as histogram. - - Each entry of each generator is checked for `notnull()` values along all the datetime axis (boolean). - Calling this class creates two different types of histograms where each generator - - 1) data_availability_histogram: datetime (xaxis) vs. number of stations with availabile data (yaxis) - 2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number - of samples (yaxis) - - .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png - :width: 400 - - .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png - :width: 400 - - """ - - def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", - subset_dim: str = 'DataSet', history_dim: str = 'window', - station_dim: str = 'Stations',): - - super().__init__(plot_folder, "data_availability_histogram") - - self.subset_dim = subset_dim - self.history_dim = history_dim - self.station_dim = station_dim - - self.freq = None - self.temporal_dim = None - self.target_dim = None - self._prepare_data(generators) - - for plt_type in self.allowed_plot_types: - plot_name_tmp = self.plot_name - self.plot_name += '_' + plt_type - self._plot(plt_type=plt_type) - self._save() - self.plot_name = plot_name_tmp - - def _set_dims_from_datahandler(self, data_handler): - self.temporal_dim = data_handler.id_class.time_dim - self.target_dim = data_handler.id_class.target_dim - self.freq = self._get_sampling(data_handler.id_class.sampling) - - @property - def allowed_plot_types(self): - plot_types = ['hist', 'hist_cum'] - return plot_types - - def _prepare_data(self, generators: Dict[str, DataCollection]): - """ - Prepares data to be used by plot methods. 
- - Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim - """ - avail_data_time_sum = {} - avail_data_station_sum = {} - dataset_time_interval = {} - for subset, generator in generators.items(): - avail_list = [] - for station in generator: - self._set_dims_from_datahandler(data_handler=station) - station_data_x = station.get_X(as_numpy=False)[0] - station_data_x = station_data_x.loc[{self.history_dim: 0, # select recent window frame - self.target_dim: station_data_x[self.target_dim].values[0]}] - station_data_x = self._reduce_dims(station_data_x) - avail_list.append(station_data_x.notnull()) - avail_data = xr.concat(avail_list, dim=self.station_dim).notnull() - avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim) - avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim) - dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray( - avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict' - ) - avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(), - name=self.subset_dim) - ) - full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq) - self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(), - name=self.subset_dim)) - self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index}) - self.dataset_time_interval = dataset_time_interval - - def _reduce_dims(self, dataset): - if len(dataset.dims) > 2: - required = {self.temporal_dim, self.station_dim} - unimportant = set(dataset.dims).difference(required) - sel_dict = {un: dataset[un].values[0] for un in unimportant} - dataset = dataset.loc[sel_dict] - return dataset - - @staticmethod - def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'): - if isinstance(xarray, xr.DataArray): - first = xarray.coords[dim_name].values[0] - last = xarray.coords[dim_name].values[-1] - if return_type == 'as_tuple': - return first, last - elif return_type == 'as_dict': - return {'first': first, 'last': last} - else: - raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'") - else: - raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}") - - @staticmethod - def _make_full_time_index(irregular_time_index, freq): - full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq) - return full_time_index - - def _plot(self, plt_type='hist', *args): - if plt_type == 'hist': - self._plot_hist() - elif plt_type == 'hist_cum': - self._plot_hist_cum() - else: - raise ValueError(f"plt_type mus be 'hist' or 'hist_cum', but is {type}") - - def _plot_hist(self, *args): - colors = self.get_dataset_colors() - fig, axes = plt.subplots(figsize=(10, 3)) - for i, subset in enumerate(self.dataset_time_interval.keys()): - plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset, - self.temporal_dim: slice( - self.dataset_time_interval[subset]['first'], - self.dataset_time_interval[subset]['last'] - ) - } - ) - - plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset) - plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset]) - - lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval), - facecolor='white', framealpha=1, edgecolor='black') - for lgd_line in 
lgd.get_lines(): - lgd_line.set_linewidth(4.0) - plt.gca().xaxis.set_major_locator(mdates.YearLocator()) - plt.title('') - plt.ylabel('Number of samples') - plt.tight_layout() - - def _plot_hist_cum(self, *args): - colors = self.get_dataset_colors() - fig, axes = plt.subplots(figsize=(10, 3)) - n_bins = int(self.avail_data_cum_sum.max().values) - bins = np.arange(0, n_bins+1) - descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby( - self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False - ).coords[self.subset_dim].values - - for subset in descending_subsets: - self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes, - bins=bins, - label=subset, - cumulative=-1, - color=colors[subset], - # alpha=.5 - ) - - lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval), - facecolor='white', framealpha=1, edgecolor='black') - plt.title('') - plt.ylabel('Number of stations') - plt.xlabel('Number of samples') - plt.xlim((bins[0], bins[-1])) - plt.tight_layout() - - - if __name__ == "__main__": stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] path = "../../testrun_network/forecasts" diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 73aebb008ebf1f61eb2878293fc160cf549d19cb..23d26fc1e5c866657d28b11f275d76df5a8cc300 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -19,9 +19,10 @@ from mlair.helpers.datastore import NameNotFoundInDataStore from mlair.helpers import TimeTracking, statistics, extract_value, remove_items, to_list, tables from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel from mlair.model_modules import AbstractModelClass -from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ - PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotAvailabilityHistogram, \ - PlotConditionalQuantiles, PlotSeparationOfScales +from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \ + PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, PlotSeparationOfScales +from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \ + PlotPeriodogram from mlair.run_modules.run_environment import RunEnvironment @@ -296,6 +297,7 @@ class PostProcessing(RunEnvironment): """ logging.info("Run plotting routines...") path = self.data_store.get("forecast_path") + use_multiprocessing = self.data_store.get("use_multiprocessing") plot_list = self.data_store.get("plot_list", "postprocessing") time_dim = self.data_store.get("time_dim") @@ -325,23 +327,6 @@ class PostProcessing(RunEnvironment): except Exception as e: logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error: {e}") - try: - if "PlotStationMap" in plot_list: - if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get( - "hostname")[:6] in self.data_store.get("hpc_hosts"): - logging.warning( - f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") - else: - gens = [(self.train_data, {"marker": 5, "ms": 9}), - (self.val_data, {"marker": 6, "ms": 9}), - (self.test_data, {"marker": 4, "ms": 9})] - PlotStationMap(generators=gens, plot_folder=self.plot_path) - gens = [(self.train_val_data, {"marker": 8, "ms": 9}), - (self.test_data, {"marker": 
9, "ms": 9})] - PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var") - except Exception as e: - logging.error(f"Could not create plot PlotStationMap due to the following error: {e}") - try: if "PlotMonthlySummary" in plot_list: PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var, @@ -372,6 +357,23 @@ class PostProcessing(RunEnvironment): except Exception as e: logging.error(f"Could not create plot PlotTimeSeries due to the following error: {e}") + try: + if "PlotStationMap" in plot_list: + if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get( + "hostname")[:6] in self.data_store.get("hpc_hosts"): + logging.warning( + f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}") + else: + gens = [(self.train_data, {"marker": 5, "ms": 9}), + (self.val_data, {"marker": 6, "ms": 9}), + (self.test_data, {"marker": 4, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=self.plot_path) + gens = [(self.train_val_data, {"marker": 8, "ms": 9}), + (self.test_data, {"marker": 9, "ms": 9})] + PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var") + except Exception as e: + logging.error(f"Could not create plot PlotStationMap due to the following error: {e}") + try: if "PlotAvailability" in plot_list: avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data} @@ -388,6 +390,14 @@ class PostProcessing(RunEnvironment): except Exception as e: logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}") + try: + if "PlotPeriodogram" in plot_list: + PlotPeriodogram(self.train_data, plot_folder=self.plot_path, time_dim=time_dim, + variables_dim=target_dim, sampling=self._sampling, + use_multiprocessing=use_multiprocessing) + except Exception as e: + logging.error(f"Could not create plot PlotPeriodogram due to the following error: {e}") + def calculate_test_score(self): """Evaluate test score of model and save locally.""" diff --git a/mlair/run_modules/pre_processing.py b/mlair/run_modules/pre_processing.py index f59a4e89ced738c9198619ec0d117df2edf3ee93..68164b6fa3c6b95727f634baebd40e988482abd5 100644 --- a/mlair/run_modules/pre_processing.py +++ b/mlair/run_modules/pre_processing.py @@ -10,7 +10,6 @@ import multiprocessing import requests import psutil -import numpy as np import pandas as pd from mlair.data_handler import DataCollection, AbstractDataHandler @@ -257,6 +256,7 @@ class PreProcessing(RunEnvironment): if dh is not None: collection.add(dh) valid_stations.append(s) + pool.close() else: # serial solution logging.info("use serial validate station approach") for station in set_stations: diff --git a/requirements.txt b/requirements.txt index 85655e237f8e10e98f77c379be6acd0a7bb65d46..dba565fbb535db7d7782baec8690971d4393b3e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ absl-py==0.11.0 appdirs==1.4.4 astor==0.8.1 +astropy==4.1 attrs==20.3.0 bottleneck==1.3.2 cached-property==1.5.2 diff --git a/requirements_gpu.txt b/requirements_gpu.txt index cc189496bdf4e1e1ee86902a1953c2058d58c8e4..f170e1b7b67df7e17a3258ca849b252acaf3e650 100644 --- a/requirements_gpu.txt +++ b/requirements_gpu.txt @@ -1,6 +1,7 @@ absl-py==0.11.0 appdirs==1.4.4 astor==0.8.1 +astropy==4.1 attrs==20.3.0 bottleneck==1.3.2 cached-property==1.5.2