Compare revisions: esde/machine-learning/mlair
Commits on Source (12)
docs/_source/_plots/periodogram.png  62.3 KiB
@@ -56,7 +56,7 @@ class DataHandlerMixedSamplingSingleStation(DataHandlerSingleStation):
         kwargs.update({parameter_name: parameter})

     def make_input_target(self):
-        self._data = list(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
+        self._data = tuple(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
         self.set_inputs_and_targets()

     def load_and_interpolate(self, ind) -> [xr.DataArray, pd.DataFrame]:
@@ -110,7 +110,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
         A KZ filter is applied on the input data that has hourly resolution. Labels Y are provided as aggregated values
         with daily resolution.
         """
-        self._data = list(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
+        self._data = tuple(map(self.load_and_interpolate, [0, 1]))  # load input (0) and target (1) data
         self.set_inputs_and_targets()
         self.apply_kz_filter()
@@ -158,7 +158,7 @@ class DataHandlerMixedSamplingWithFilterSingleStation(DataHandlerMixedSamplingSi
     def _extract_lazy(self, lazy_data):
         _data, self.meta, _input_data, _target_data, self.cutoff_period, self.cutoff_period_days = lazy_data
         start_inp, end_inp = self.update_start_end(0)
-        self._data = list(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1]))
+        self._data = tuple(map(lambda x: self._slice_prep(_data[x], *self.update_start_end(x)), [0, 1]))
         self.input_data = self._slice_prep(_input_data, start_inp, end_inp)
         self.target_data = self._slice_prep(_target_data, self.start, self.end)
......
@@ -299,6 +299,7 @@ class DefaultDataHandler(AbstractDataHandler):
                 for p in output:
                     dh, s = p.get()
                     _inner()
+            pool.close()
         else:  # serial solution
             logging.info("use serial transformation approach")
             for station in set_stations:
......
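The hunk above adds a `pool.close()` once all `apply_async` results have been fetched, so the worker pool is actually released; the same fix reappears in the PreProcessing hunk near the end of this comparison. A minimal sketch of that multiprocessing pattern, using a hypothetical worker `square`:

import multiprocessing

def square(x):  # hypothetical worker, stands in for the real processing function
    return x * x

if __name__ == "__main__":
    pool = multiprocessing.Pool(2)
    output = [pool.apply_async(square, args=(i,)) for i in range(4)]
    results = [p.get() for p in output]  # fetch all results first
    pool.close()  # then release the worker processes
    pool.join()   # and wait for them to terminate
    print(results)  # [0, 1, 4, 9]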
"""Abstract plot class that should be used for preprocessing and postprocessing plots."""
__author__ = "Lukas Leufen"
__date__ = '2021-04-13'
import logging
import os
from matplotlib import pyplot as plt
class AbstractPlotClass:
"""
Abstract class for all plotting routines to unify plot workflow.
Each inheritance requires a _plot method. Create a plot class like:
.. code-block:: python
class MyCustomPlot(AbstractPlotClass):
def __init__(self, plot_folder, *args, **kwargs):
super().__init__(plot_folder, "custom_plot_name")
self._data = self._prepare_data(*args, **kwargs)
self._plot(*args, **kwargs)
self._save()
def _prepare_data(*args, **kwargs):
<your custom data preparation>
return data
def _plot(*args, **kwargs):
<your custom plotting without saving>
The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are
using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500 dpi by default (this can
be set in the superclass initialisation).
Methods like the shown _prepare_data() are optional. The only method that must be implemented is _plot.
If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot
class. It will then log the time spent whenever your plot class is called.
.. code-block:: python
@TimeTrackingWrapper
class MyCustomPlot(AbstractPlotClass):
pass
Let's assume it takes a while to create this very special plot.
>>> MyCustomPlot()
INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss)
"""
def __init__(self, plot_folder, plot_name, resolution=500, rc_params=None):
"""Set up plot folder and name, and plot resolution (default 500dpi)."""
plot_folder = os.path.abspath(plot_folder)
if not os.path.exists(plot_folder):
os.makedirs(plot_folder)
self.plot_folder = plot_folder
self.plot_name = plot_name
self.resolution = resolution
if rc_params is None:
rc_params = {'axes.labelsize': 'large',
'xtick.labelsize': 'large',
'ytick.labelsize': 'large',
'legend.fontsize': 'large',
'axes.titlesize': 'large',
}
self.rc_params = rc_params
self._update_rc_params()
def _plot(self, *args):
"""Abstract plot class needs to be implemented in inheritance."""
raise NotImplementedError
def _save(self, **kwargs):
"""Store plot locally. Name of and path to plot need to be set on initialisation."""
plot_name = os.path.join(self.plot_folder, f"{self.plot_name}.pdf")
logging.debug(f"... save plot to {plot_name}")
plt.savefig(plot_name, dpi=self.resolution, **kwargs)
plt.close('all')
def _update_rc_params(self):
plt.rcParams.update(self.rc_params)
@staticmethod
def _get_sampling(sampling):
if sampling == "daily":
return "D"
elif sampling == "hourly":
return "h"
@staticmethod
def get_dataset_colors():
"""
Standard colors used for train-, val-, and test-sets during postprocessing
"""
colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9", "train_val": "#000000"} # hex code
return colors
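A minimal runnable subclass sketch following the docstring template above (the sine data is made up for illustration):

import numpy as np
from matplotlib import pyplot as plt
from mlair.plotting.abstract_plot_class import AbstractPlotClass

class MyCustomPlot(AbstractPlotClass):
    def __init__(self, plot_folder="."):
        super().__init__(plot_folder, "custom_plot_name")  # saved as custom_plot_name.pdf at 500 dpi
        self._data = self._prepare_data()
        self._plot()
        self._save()

    def _prepare_data(self):
        return np.sin(np.linspace(0, 2 * np.pi, 100))  # toy data

    def _plot(self, *args):
        fig, ax = plt.subplots()
        ax.plot(self._data)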
"""Collection of plots to get more insight into data."""
__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2021-04-13'
from typing import List, Dict
import os
import logging
import multiprocessing
import psutil
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib
from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates
from astropy.timeseries import LombScargle
from mlair.data_handler import DataCollection
from mlair.helpers import TimeTrackingWrapper, to_list
from mlair.plotting.abstract_plot_class import AbstractPlotClass
@TimeTrackingWrapper
class PlotStationMap(AbstractPlotClass): # pragma: no cover
"""
Plot geographical overview of all used stations as squares.
Different data sets can be colorised by their key in the input dictionary generators. The key represents the color to
plot on the map. Currently, there is only a white background, but this can be adjusted by loading locally stored
topography data (not implemented yet). The plot is saved under plot_path with the name station_map.pdf.
.. image:: ../../../../../_source/_plots/station_map.png
:width: 400
"""
def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
"""
Set attributes and create plot.
:param generators: list of data collections to plot; each element may also be a (collection, plot options) tuple.
    The collection name selects the marker color on the map.
:param plot_folder: path to save the plot (default: current directory)
"""
super().__init__(plot_folder, plot_name)
self._ax = None
self._gl = None
self._plot(generators)
self._save(bbox_inches="tight")
def _draw_background(self):
"""Draw coastline, lakes, ocean, rivers and country borders as background on the map."""
import cartopy.feature as cfeature
self._ax.add_feature(cfeature.LAND.with_scale("50m"))
self._ax.natural_earth_shp(resolution='50m')
self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')
def _plot_stations(self, generators):
"""
Loop over all elements of generators and their contained stations, and plot each station's position.
The position is highlighted by a marker on the map in the given color.
:param generators: list of data collections to plot; each element may also be a (collection, plot options) tuple.
"""
import cartopy.crs as ccrs
if generators is not None:
legend_elements = []
default_colors = self.get_dataset_colors()
for element in generators:
data_collection, plot_opts = self._get_collection_and_opts(element)
name = data_collection.name or "unknown"
marker = plot_opts.get("marker", "s")
ms = plot_opts.get("ms", 6)
mec = plot_opts.get("mec", "k")
mfc = plot_opts.get("mfc", default_colors.get(name, "b"))
legend_elements.append(
mlines.Line2D([], [], mfc=mfc, mec=mec, marker=self._adjust_marker(marker), ms=ms, linestyle='None',
label=f"{name} ({len(data_collection)})"))
for station in data_collection:
coords = station.get_coordinates()
IDx, IDy = coords["lon"], coords["lat"]
self._ax.plot(IDx, IDy, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree())
if len(legend_elements) > 0:
self._ax.legend(handles=legend_elements, loc='best')
@staticmethod
def _adjust_marker(marker):
_adjust = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
if isinstance(marker, int) and marker in _adjust.keys():
return _adjust[marker]
else:
return marker
@staticmethod
def _get_collection_and_opts(element):
if isinstance(element, tuple):
if len(element) == 1:
return element[0], {}
else:
return element
else:
return element, {}
def _plot(self, generators: List):
"""
Create the station map plot.
Set figure and call all required sub-methods.
:param generators: list of data collections to plot; each element may also be a (collection, plot options) tuple.
"""
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
fig = plt.figure(figsize=(10, 5))
self._ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
self._gl = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
self._gl.xformatter = LONGITUDE_FORMATTER
self._gl.yformatter = LATITUDE_FORMATTER
self._draw_background()
self._plot_stations(generators)
self._adjust_extent()
plt.tight_layout()
def _adjust_extent(self):
import cartopy.crs as ccrs
def diff(arr):
return arr[1] - arr[0], arr[3] - arr[2]
def find_ratio(delta, reference=5):
return min(max(abs(reference / delta[0]), abs(reference / delta[1])), 5)
extent = self._ax.get_extent(crs=ccrs.PlateCarree())
ratio = find_ratio(diff(extent))
new_extent = extent + np.array([-1, 1, -1, 1]) * ratio
self._ax.set_extent(new_extent, crs=ccrs.PlateCarree())
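# Hedged usage sketch, mirroring the PostProcessing wiring shown further below
# (train_data, val_data and test_data are assumed DataCollection instances):
#     gens = [(train_data, {"marker": 5, "ms": 9}),
#             (val_data, {"marker": 6, "ms": 9}),
#             (test_data, {"marker": 4, "ms": 9})]
#     PlotStationMap(generators=gens, plot_folder=".")  # writes station_map.pdf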
@TimeTrackingWrapper
class PlotAvailability(AbstractPlotClass): # pragma: no cover
"""
Create a data availability plot similar to a Gantt plot.
Each entry of the given generator will result in a new line in the plot. Data is summarised for the given temporal
resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a
colored bar or a blank space.
You can set different colors to highlight subsets, for example by providing different generators for the same index
using different keys in the input dictionary.
Note: each bar is surrounded by a small white box to highlight gaps in between. This can result in overly long gaps
in the display if a gap is actually very short. This also appears on a (fluent) transition from one subset to another.
Calling this class will create three versions of the availability plot.
1) Data availability for each element
2) Data availability as summary over all elements (is there at least a single element for each time step)
3) Combination of single and overall availability
.. image:: ../../../../../_source/_plots/data_availability.png
:width: 400
.. image:: ../../../../../_source/_plots/data_availability_summary.png
:width: 400
.. image:: ../../../../../_source/_plots/data_availability_combined.png
:width: 400
"""
def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
summary_name="data availability", time_dimension="datetime", window_dimension="window"):
"""Initialise."""
# create standard Gantt plot for all stations (currently in single pdf file with single page)
super().__init__(plot_folder, "data_availability")
self.time_dim = time_dimension
self.window_dim = window_dimension
self.sampling = self._get_sampling(sampling)
self.linewidth = None
if self.sampling == 'h':
self.linewidth = 0.001
plot_dict = self._prepare_data(generators)
lgd = self._plot(plot_dict)
self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
# create summary Gantt plot (is data in at least one station available)
self.plot_name += "_summary"
plot_dict_summary = self._summarise_data(generators, summary_name)
lgd = self._plot(plot_dict_summary)
self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
# combination of station and summary plot, last element is summary broken bar
self.plot_name = "data_availability_combined"
plot_dict_summary.update(plot_dict)
lgd = self._plot(plot_dict_summary)
self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")
def _prepare_data(self, generators: Dict[str, DataCollection]):
plt_dict = {}
for subset, data_collection in generators.items():
for station in data_collection:
labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
labels_bool = labels.sel(**{self.window_dim: 1}).notnull()
group = (labels_bool != labels_bool.shift({self.time_dim: 1})).cumsum()
plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values},
index=labels.coords[self.time_dim].values)
t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
t2 = [i[1:] for i in t if i[0]]
if plt_dict.get(str(station)) is None:
plt_dict[str(station)] = {subset: t2}
else:
plt_dict[str(station)].update({subset: t2})
return plt_dict
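# The shift/cumsum idiom above gives every run of consecutive equal availability
# values its own group id; groupby then collapses each run into the (start, length)
# tuples consumed by ax.broken_barh. Tiny pandas demo of the idiom:
#     s = pd.Series([True, True, False, True])
#     (s != s.shift()).cumsum().tolist()  # -> [1, 1, 2, 3], i.e. three runs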
def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
plt_dict = {}
for subset, data_collection in generators.items():
all_data = None
for station in data_collection:
labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
labels_bool = labels.sel(**{self.window_dim: 1}).notnull()
if all_data is None:
all_data = labels_bool
else:
tmp = all_data.combine_first(labels_bool) # expand dims to merged datetime coords
all_data = np.logical_or(tmp, labels_bool).combine_first(
all_data) # apply logical on merge and fill missing with all_data
group = (all_data != all_data.shift({self.time_dim: 1})).cumsum()
plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values},
index=all_data.coords[self.time_dim].values)
t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0]))
t2 = [i[1:] for i in t if i[0]]
if plt_dict.get(summary_name) is None:
plt_dict[summary_name] = {subset: t2}
else:
plt_dict[summary_name].update({subset: t2})
return plt_dict
def _plot(self, plt_dict):
colors = self.get_dataset_colors()
_used_colors = []
pos = 0
height = 0.8 # should be <= 1
yticklabels = []
number_of_stations = len(plt_dict.keys())
fig, ax = plt.subplots(figsize=(10, number_of_stations / 3))
for station, d in sorted(plt_dict.items(), reverse=True):
pos += 1
for subset, color in colors.items():
plt_data = d.get(subset)
if plt_data is None:
continue
elif color not in _used_colors: # this is required for a proper legend creation
_used_colors.append(color)
ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth)
yticklabels.append(station)
ax.set_ylim([height, number_of_stations + 1])
ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2)
ax.set_yticklabels(yticklabels)
handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors]
lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
return lgd
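# Hedged usage sketch; the dictionary keys must match the subsets known to
# get_dataset_colors() (train/val/test/train_val), and the values are assumed
# DataCollection instances:
#     avail_data = {"train": train_data, "val": val_data, "test": test_data}
#     PlotAvailability(avail_data, plot_folder=".", sampling="daily")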
@TimeTrackingWrapper
class PlotAvailabilityHistogram(AbstractPlotClass): # pragma: no cover
"""
Create data availability plots as histograms.
Each entry of each generator is checked for `notnull()` values along the datetime axis (boolean).
Calling this class creates two different types of histograms for each generator:
1) data_availability_histogram: datetime (x-axis) vs. number of stations with available data (y-axis)
2) data_availability_histogram_cumulative: number of samples (x-axis) vs. number of stations having at least that
number of samples (y-axis)
.. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png
:width: 400
.. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png
:width: 400
"""
def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".",
subset_dim: str = 'DataSet', history_dim: str = 'window',
station_dim: str = 'Stations', ):
super().__init__(plot_folder, "data_availability_histogram")
self.subset_dim = subset_dim
self.history_dim = history_dim
self.station_dim = station_dim
self.freq = None
self.temporal_dim = None
self.target_dim = None
self._prepare_data(generators)
for plt_type in self.allowed_plot_types:
plot_name_tmp = self.plot_name
self.plot_name += '_' + plt_type
self._plot(plt_type=plt_type)
self._save()
self.plot_name = plot_name_tmp
def _set_dims_from_datahandler(self, data_handler):
self.temporal_dim = data_handler.id_class.time_dim
self.target_dim = data_handler.id_class.target_dim
self.freq = self._get_sampling(data_handler.id_class.sampling)
@property
def allowed_plot_types(self):
plot_types = ['hist', 'hist_cum']
return plot_types
def _prepare_data(self, generators: Dict[str, DataCollection]):
"""
Prepares data to be used by plot methods.
Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim
"""
avail_data_time_sum = {}
avail_data_station_sum = {}
dataset_time_interval = {}
for subset, generator in generators.items():
avail_list = []
for station in generator:
self._set_dims_from_datahandler(data_handler=station)
station_data_x = station.get_X(as_numpy=False)[0]
station_data_x = station_data_x.loc[{self.history_dim: 0, # select recent window frame
self.target_dim: station_data_x[self.target_dim].values[0]}]
station_data_x = self._reduce_dims(station_data_x)
avail_list.append(station_data_x.notnull())
avail_data = xr.concat(avail_list, dim=self.station_dim).notnull()
avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim)
avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim)
dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray(
avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict'
)
avail_data_amount = xr.concat(avail_data_time_sum.values(), pd.Index(avail_data_time_sum.keys(),
name=self.subset_dim)
)
full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq)
self.avail_data_cum_sum = xr.concat(avail_data_station_sum.values(), pd.Index(avail_data_station_sum.keys(),
name=self.subset_dim))
self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index})
self.dataset_time_interval = dataset_time_interval
def _reduce_dims(self, dataset):
if len(dataset.dims) > 2:
required = {self.temporal_dim, self.station_dim}
unimportant = set(dataset.dims).difference(required)
sel_dict = {un: dataset[un].values[0] for un in unimportant}
dataset = dataset.loc[sel_dict]
return dataset
@staticmethod
def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'):
if isinstance(xarray, xr.DataArray):
first = xarray.coords[dim_name].values[0]
last = xarray.coords[dim_name].values[-1]
if return_type == 'as_tuple':
return first, last
elif return_type == 'as_dict':
return {'first': first, 'last': last}
else:
raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'")
else:
raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}")
@staticmethod
def _make_full_time_index(irregular_time_index, freq):
full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq)
return full_time_index
def _plot(self, plt_type='hist', *args):
if plt_type == 'hist':
self._plot_hist()
elif plt_type == 'hist_cum':
self._plot_hist_cum()
else:
raise ValueError(f"plt_type mus be 'hist' or 'hist_cum', but is {type}")
def _plot_hist(self, *args):
colors = self.get_dataset_colors()
fig, axes = plt.subplots(figsize=(10, 3))
for i, subset in enumerate(self.dataset_time_interval.keys()):
plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset,
self.temporal_dim: slice(
self.dataset_time_interval[subset]['first'],
self.dataset_time_interval[subset]['last']
)
}
)
plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset)
plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset])
lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
facecolor='white', framealpha=1, edgecolor='black')
for lgd_line in lgd.get_lines():
lgd_line.set_linewidth(4.0)
plt.gca().xaxis.set_major_locator(mdates.YearLocator())
plt.title('')
plt.ylabel('Number of samples')
plt.tight_layout()
def _plot_hist_cum(self, *args):
colors = self.get_dataset_colors()
fig, axes = plt.subplots(figsize=(10, 3))
n_bins = int(self.avail_data_cum_sum.max().values)
bins = np.arange(0, n_bins + 1)
descending_subsets = self.avail_data_cum_sum.max(dim=self.station_dim).sortby(
self.avail_data_cum_sum.max(dim=self.station_dim), ascending=False
).coords[self.subset_dim].values
for subset in descending_subsets:
self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes,
bins=bins,
label=subset,
cumulative=-1,
color=colors[subset],
# alpha=.5
)
lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
facecolor='white', framealpha=1, edgecolor='black')
plt.title('')
plt.ylabel('Number of stations')
plt.xlabel('Number of samples')
plt.xlim((bins[0], bins[-1]))
plt.tight_layout()
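# Hedged usage sketch; both plot types from allowed_plot_types are created and
# saved on construction (the generator values are assumed DataCollection instances):
#     PlotAvailabilityHistogram({"train": train_data, "val": val_data,
#                                "test": test_data}, plot_folder=".")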
class PlotPeriodogram(AbstractPlotClass): # pragma: no cover
"""
Create a Lomb-Scargle periodogram of raw input and target data. The Lomb-Scargle version can deal with missing values.
This plot routine creates the following plots:
* "raw": data is not aggregated, 1 graph per variable
* "" (no suffix): single data lines are aggregated, 1 graph per variable
* "total": data is aggregated over all variables, single graph
If the data consists of different sampling rates, a separate plot is created for each sampling.
.. image:: ../../../../../_source/_plots/periodogram.png
:width: 400
.. note::
This plot is not included in the default plot list. To use this plot, add "PlotPeriodogram" to the `plot_list`.
.. warning::
This plot is highly sensitive to the data handler structure. Therefore, it is highly likely that this method is
not compatible with any custom data handler. Proven data handlers are `DefaultDataHandler`,
`DataHandlerMixedSampling`, `DataHandlerMixedSamplingWithFilter`. To work properly, the data handler must have
the attribute `.id_class._data`.
"""
def __init__(self, generator: Dict[str, DataCollection], plot_folder: str = ".", plot_name="periodogram",
variables_dim="variables", time_dim="datetime", sampling="daily", use_multiprocessing=False):
super().__init__(plot_folder, plot_name)
self.variables_dim = variables_dim
self.time_dim = time_dim
for pos, s in enumerate(sampling if isinstance(sampling, tuple) else (sampling, sampling)):
self._sampling = s
self._add_text = {0: "input", 1: "target"}[pos]
multiple, label_names = self._has_filter_dimension(generator[0], pos)
self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing)
self._plot(raw=True)
self._plot(raw=False)
self._plot_total(raw=True)
self._plot_total(raw=False)
if multiple > 1:
self._plot_difference(label_names)
@staticmethod
def _has_filter_dimension(g, pos):
# check if coords of raw data differ from input / target data
check_data = g.id_class
if "filter" not in [check_data.input_data, check_data.target_data][pos].coords.dims:
return 1, []
else:
if len(set(check_data._data[0].coords.dims).symmetric_difference(check_data.input_data.coords.dims)) > 0:
return g.id_class.input_data.coords["filter"].shape[0], g.id_class.input_data.coords[
"filter"].values.tolist()
else:
return 1, []
@TimeTrackingWrapper
def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False):
"""
Create periodogram data.
"""
self.raw_data = []
self.plot_data = []
self.plot_data_raw = []
self.plot_data_mean = []
iter = range(multiple if multiple == 1 else multiple + 1)
for m in iter:
plot_data_single = dict()
plot_data_raw_single = dict()
plot_data_mean_single = dict()
raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing)
# raw_data_single = self._prepare_pgram_parallel_var(generator, m, pos, use_multiprocessing)
self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000)
for var in raw_data_single.keys():
pgram_com = []
pgram_mean = 0
all_data = raw_data_single[var]
pgram_mean_raw = np.zeros((len(self.f_index), len(all_data)))
for i, (f, pgram) in enumerate(all_data):
d = np.interp(self.f_index, f, pgram)
pgram_com.append(d)
pgram_mean += d
pgram_mean_raw[:, i] = d
pgram_mean /= len(all_data)
plot_data_single[var] = pgram_com
plot_data_mean_single[var] = (self.f_index, pgram_mean)
plot_data_raw_single[var] = (self.f_index, pgram_mean_raw)
self.plot_data.append(plot_data_single)
self.plot_data_mean.append(plot_data_mean_single)
self.plot_data_raw.append(plot_data_raw_single)
def _prepare_pgram_parallel_var(self, generator, m, pos, use_multiprocessing):
"""Implementation of data preprocessing using parallel variables element processing."""
raw_data_single = dict()
for g in generator:
if m == 0:
d = g.id_class._data
else:
gd = g.id_class
filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]}
d = (gd.input_data.sel(filter_sel), gd.target_data)
d = d[pos] if isinstance(d, tuple) else d
res = []
if multiprocessing.cpu_count() > 1 and use_multiprocessing: # parallel solution
pool = multiprocessing.Pool(
min([psutil.cpu_count(logical=False), len(d[self.variables_dim].values),
16])) # use only physical cpus
output = [
pool.apply_async(f_proc,
args=(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)))
for var in d[self.variables_dim].values]
for i, p in enumerate(output):
res.append(p.get())
pool.close()
else: # serial solution
for var in d[self.variables_dim].values:
res.append(f_proc(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)))
for (var_str, f, pgram) in res:
if var_str not in raw_data_single.keys():
raw_data_single[var_str] = [(f, pgram)]
else:
raw_data_single[var_str] = raw_data_single[var_str] + [(f, pgram)]
return raw_data_single
def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing):
"""Implementation of data preprocessing using parallel generator element processing."""
raw_data_single = dict()
res = []
if multiprocessing.cpu_count() > 1 and use_multiprocessing: # parallel solution
pool = multiprocessing.Pool(
min([psutil.cpu_count(logical=False), len(generator), 16])) # use only physical cpus
output = [
pool.apply_async(f_proc_2, args=(g, m, pos, self.variables_dim, self.time_dim))
for g in generator]
for i, p in enumerate(output):
res.append(p.get())
pool.close()
else:
for g in generator:
res.append(f_proc_2(g, m, pos, self.variables_dim, self.time_dim))
for res_dict in res:
for k, v in res_dict.items():
if k not in raw_data_single.keys():
raw_data_single[k] = v
else:
raw_data_single[k] = raw_data_single[k] + v
return raw_data_single
@staticmethod
def _add_annotation_line(ax, pos, div, lims, unit):
for p in to_list(pos): # per year
ax.vlines(p / div, *lims, "black")
ax.text(p / div, lims[0], r"%s$%s^{-1}$" % (p, unit), rotation="vertical", rotation_mode="anchor")
def _format_figure(self, ax, var_name="total"):
"""
Set log scale on both axis, add labels and annotation lines, and set title.
:param ax: current ax object
:param var_name: name of variable that will be included in the title
"""
ax.set_yscale('log')
ax.set_xscale('log')
ax.set_ylabel("power", fontsize='x-large')
ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large')
lims = ax.get_ylim()
self._add_annotation_line(ax, [1, 2, 3], 365.25, lims, "yr") # per year
self._add_annotation_line(ax, 1, 365.25 / 12, lims, "m") # per month
self._add_annotation_line(ax, 1, 7, lims, "w") # per week
self._add_annotation_line(ax, [1, 0.5], 1, lims, "d") # per day
if self._sampling == "hourly":
self._add_annotation_line(ax, 2, 1, lims, "d") # per day
self._add_annotation_line(ax, [1, 0.5], 1 / 24., lims, "h") # per hour
title = f"Periodogram ({var_name})"
ax.set_title(title)
def _plot(self, raw=True):
plot_path = os.path.join(os.path.abspath(self.plot_folder),
f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}.pdf")
pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
plot_data = self.plot_data[0]
plot_data_mean = self.plot_data_mean[0]
for var in plot_data.keys():
fig, ax = plt.subplots()
if raw is True:
for pgram in plot_data[var]:
ax.plot(self.f_index, pgram, "lightblue")
ax.plot(*plot_data_mean[var], "blue")
else:
ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0)
mean = ma.mean().mean(axis=1).values.flatten()
upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
ax.plot(self.f_index, mean, "blue")
ax.fill_between(self.f_index, lower, upper, color="lightblue")
self._format_figure(ax, var)
pdf_pages.savefig()
# close all open figures / plots
pdf_pages.close()
plt.close('all')
def _plot_total(self, raw=True):
plot_path = os.path.join(os.path.abspath(self.plot_folder),
f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}_total.pdf")
pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
plot_data_raw = self.plot_data_raw[0]
fig, ax = plt.subplots()
res = None
for var in plot_data_raw.keys():
d_var = plot_data_raw[var][1]
res = d_var if res is None else np.concatenate((res, d_var), axis=-1)
if raw is True:
for i in range(res.shape[1]):
ax.plot(self.f_index, res[:, i], "lightblue")
ax.plot(self.f_index, res.mean(axis=1), "blue")
else:
ma = pd.DataFrame(np.vstack(res)).rolling(5, center=True, axis=0)
mean = ma.mean().mean(axis=1).values.flatten()
upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
ax.plot(self.f_index, mean, "blue")
ax.fill_between(self.f_index, lower, upper, color="lightblue")
self._format_figure(ax, "total")
pdf_pages.savefig()
# close all open figures / plots
pdf_pages.close()
plt.close('all')
def _plot_difference(self, label_names):
plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter.pdf"
plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name)
logging.info(f"... plotting {plot_name}")
pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path)
colors = ["blue", "red", "green", "orange", "purple", "black", "grey"]
label_names = ["orig"] + label_names
max_iter = len(self.plot_data)
var_keys = self.plot_data[0].keys()
for var in var_keys:
fig, ax = plt.subplots()
for i in reversed(range(max_iter)):
plot_data = self.plot_data[i]
c = colors[i]
ma = pd.DataFrame(np.vstack(plot_data[var]).T).rolling(5, center=True, axis=0)
mean = ma.mean().mean(axis=1).values.flatten()
ax.plot(self.f_index, mean, c, label=label_names[i])
if i < 1:
upper, lower = ma.max().mean(axis=1).values.flatten(), ma.min().mean(axis=1).values.flatten()
ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5, label=None)
self._format_figure(ax, var)
ax.legend(loc="upper center", ncol=max_iter)
pdf_pages.savefig()
# close all open figures / plots
pdf_pages.close()
plt.close('all')
def f_proc(var, d_var):
var_str = str(var)
t = (d_var.datetime - d_var.datetime[0]).astype("timedelta64[h]").values / np.timedelta64(1, "D")
f, pgram = LombScargle(t, d_var.values.flatten(), nterms=1).autopower()
return var_str, f, pgram
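# Hedged demo of f_proc on synthetic data: a pure yearly cycle sampled daily
# should peak near 1/365.25 cycles per day in the periodogram:
#     time = pd.date_range("2010-01-01", periods=730, freq="D")
#     values = np.sin(2 * np.pi * np.arange(730) / 365.25)
#     d_var = xr.DataArray(values, coords={"datetime": time}, dims="datetime")
#     _, f, pgram = f_proc("demo", d_var)
#     f[np.argmax(pgram)]  # expected close to 1 / 365.25 = 0.0027 day^-1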
def f_proc_2(g, m, pos, variables_dim, time_dim):
raw_data_single = dict()
if m == 0:
d = g.id_class._data
else:
gd = g.id_class
filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]}
d = (gd.input_data.sel(filter_sel), gd.target_data)
d = d[pos] if isinstance(d, tuple) else d
for var in d[variables_dim].values:
d_var = d.loc[{variables_dim: var}].squeeze().dropna(time_dim)
var_str, f, pgram = f_proc(var, d_var)
raw_data_single[var_str] = [(f, pgram)]
return raw_data_single
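A hedged end-to-end sketch of PlotPeriodogram, mirroring the call added to PostProcessing further below; `train_data` is assumed to be a DataCollection whose handlers expose `.id_class._data`, as the class docstring warns:

from mlair.plotting.data_insight_plotting import PlotPeriodogram

PlotPeriodogram(train_data, plot_folder=".", time_dim="datetime",
                variables_dim="variables", sampling="daily",
                use_multiprocessing=False)  # writes periodogram_*.pdf files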
@@ -9,10 +9,7 @@ import warnings
from typing import Dict, List, Tuple
import matplotlib
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
import pandas as pd
import seaborn as sns
@@ -22,6 +19,7 @@ from matplotlib.backends.backend_pdf import PdfPages
 from mlair import helpers
 from mlair.data_handler.iterator import DataCollection
 from mlair.helpers import TimeTrackingWrapper
+from mlair.plotting.abstract_plot_class import AbstractPlotClass
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
@@ -31,100 +29,6 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
# import matplotlib.pyplot as plt
[94 lines removed: class AbstractPlotClass, identical to the version shown above and now living in mlair/plotting/abstract_plot_class.py]
@TimeTrackingWrapper
class PlotMonthlySummary(AbstractPlotClass):
"""
@@ -230,132 +134,6 @@ class PlotMonthlySummary:
plt.tight_layout()
[126 lines removed: class PlotStationMap, identical to the version shown above and now living in mlair/plotting/data_insight_plotting.py]
@TimeTrackingWrapper
class PlotConditionalQuantiles(AbstractPlotClass):
"""
@@ -1138,133 +916,6 @@ class PlotTimeSeries:
return matplotlib.backends.backend_pdf.PdfPages(plot_name)
[127 lines removed: class PlotAvailability, identical to the version shown above and now living in mlair/plotting/data_insight_plotting.py]
@TimeTrackingWrapper
class PlotSeparationOfScales(AbstractPlotClass):
@@ -1292,178 +943,6 @@ class PlotSeparationOfScales:
self._save()
[172 lines removed: class PlotAvailabilityHistogram, identical to the version shown above and now living in mlair/plotting/data_insight_plotting.py]
if __name__ == "__main__":
stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
path = "../../testrun_network/forecasts"
......
@@ -19,9 +19,10 @@ from mlair.helpers.datastore import NameNotFoundInDataStore
 from mlair.helpers import TimeTracking, statistics, extract_value, remove_items, to_list, tables
 from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules import AbstractModelClass
-from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
-    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotAvailabilityHistogram, \
-    PlotConditionalQuantiles, PlotSeparationOfScales
+from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \
+    PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotConditionalQuantiles, PlotSeparationOfScales
+from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \
+    PlotPeriodogram
 from mlair.run_modules.run_environment import RunEnvironment
@@ -296,6 +297,7 @@ class PostProcessing(RunEnvironment):
         """
         logging.info("Run plotting routines...")
         path = self.data_store.get("forecast_path")
+        use_multiprocessing = self.data_store.get("use_multiprocessing")
         plot_list = self.data_store.get("plot_list", "postprocessing")
         time_dim = self.data_store.get("time_dim")
@@ -325,23 +327,6 @@ class PostProcessing(RunEnvironment):
         except Exception as e:
             logging.error(f"Could not create plot PlotConditionalQuantiles due to the following error: {e}")
-        try:
-            if "PlotStationMap" in plot_list:
-                if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get(
-                        "hostname")[:6] in self.data_store.get("hpc_hosts"):
-                    logging.warning(
-                        f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}")
-                else:
-                    gens = [(self.train_data, {"marker": 5, "ms": 9}),
-                            (self.val_data, {"marker": 6, "ms": 9}),
-                            (self.test_data, {"marker": 4, "ms": 9})]
-                    PlotStationMap(generators=gens, plot_folder=self.plot_path)
-                    gens = [(self.train_val_data, {"marker": 8, "ms": 9}),
-                            (self.test_data, {"marker": 9, "ms": 9})]
-                    PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var")
-        except Exception as e:
-            logging.error(f"Could not create plot PlotStationMap due to the following error: {e}")
         try:
             if "PlotMonthlySummary" in plot_list:
                 PlotMonthlySummary(self.test_data.keys(), path, r"forecasts_%s_test.nc", self.target_var,
@@ -372,6 +357,23 @@ class PostProcessing(RunEnvironment):
         except Exception as e:
             logging.error(f"Could not create plot PlotTimeSeries due to the following error: {e}")
+        try:
+            if "PlotStationMap" in plot_list:
+                if self.data_store.get("hostname")[:2] in self.data_store.get("hpc_hosts") or self.data_store.get(
+                        "hostname")[:6] in self.data_store.get("hpc_hosts"):
+                    logging.warning(
+                        f"Skip 'PlotStationMap` because running on a hpc node: {self.data_store.get('hostname')}")
+                else:
+                    gens = [(self.train_data, {"marker": 5, "ms": 9}),
+                            (self.val_data, {"marker": 6, "ms": 9}),
+                            (self.test_data, {"marker": 4, "ms": 9})]
+                    PlotStationMap(generators=gens, plot_folder=self.plot_path)
+                    gens = [(self.train_val_data, {"marker": 8, "ms": 9}),
+                            (self.test_data, {"marker": 9, "ms": 9})]
+                    PlotStationMap(generators=gens, plot_folder=self.plot_path, plot_name="station_map_var")
+        except Exception as e:
+            logging.error(f"Could not create plot PlotStationMap due to the following error: {e}")
         try:
             if "PlotAvailability" in plot_list:
                 avail_data = {"train": self.train_data, "val": self.val_data, "test": self.test_data}
@@ -388,6 +390,14 @@ class PostProcessing(RunEnvironment):
         except Exception as e:
             logging.error(f"Could not create plot PlotAvailabilityHistogram due to the following error: {e}")
+        try:
+            if "PlotPeriodogram" in plot_list:
+                PlotPeriodogram(self.train_data, plot_folder=self.plot_path, time_dim=time_dim,
+                                variables_dim=target_dim, sampling=self._sampling,
+                                use_multiprocessing=use_multiprocessing)
+        except Exception as e:
+            logging.error(f"Could not create plot PlotPeriodogram due to the following error: {e}")
def calculate_test_score(self):
"""Evaluate test score of model and save locally."""
......
@@ -10,7 +10,6 @@ import multiprocessing
import requests
import psutil
import numpy as np
import pandas as pd
from mlair.data_handler import DataCollection, AbstractDataHandler
@@ -257,6 +256,7 @@ class PreProcessing(RunEnvironment):
                     if dh is not None:
                         collection.add(dh)
                         valid_stations.append(s)
+            pool.close()
         else:  # serial solution
             logging.info("use serial validate station approach")
             for station in set_stations:
......
absl-py==0.11.0
appdirs==1.4.4
astor==0.8.1
astropy==4.1
attrs==20.3.0
bottleneck==1.3.2
cached-property==1.5.2
......
absl-py==0.11.0
appdirs==1.4.4
astor==0.8.1
astropy==4.1
attrs==20.3.0
bottleneck==1.3.2
cached-property==1.5.2
......