Skip to content
Snippets Groups Projects

include plot script restructuring

Merged Ghost User requested to merge develop into lukas_issue299_feat_histogram_plots
1 file
+ 1
0
Compare changes
  • Side-by-side
  • Inline
+ 721
0
"""Collection of plots to get more insight into data."""
__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2021-04-13'
from typing import List, Dict
import os
import logging
import multiprocessing
import psutil
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib
from matplotlib import lines as mlines, pyplot as plt, patches as mpatches, dates as mdates
from astropy.timeseries import LombScargle
from mlair.data_handler import DataCollection
from mlair.helpers import TimeTrackingWrapper, to_list
from mlair.plotting.abstract_plot_class import AbstractPlotClass
@TimeTrackingWrapper
class PlotStationMap(AbstractPlotClass):  # pragma: no cover
    """
    Plot a geographical overview of all used stations as markers on a map.

    Each entry of the ``generators`` argument is one data set. Its stations are drawn with the marker style given by
    the entry's plot options (or defaults), and a legend entry with the data set name and its number of stations is
    added. The map background is built from cartopy features (land, coastline, lakes, ocean, rivers, borders).
    Saving is delegated to the parent class using ``plot_folder`` and ``plot_name``.

    .. image:: ../../../../../_source/_plots/station_map.png
       :width: 400
    """

    def __init__(self, generators: List, plot_folder: str = ".", plot_name="station_map"):
        """
        Set attributes, draw the map and save it.

        :param generators: iterable of data sets; each element is either a data collection or a tuple of
            ``(data collection, plot options dict)``. Recognised option keys are ``marker``, ``ms``, ``mec``, ``mfc``.
        :param plot_folder: path to save the plot (default: current directory)
        :param plot_name: name handed to the parent class for the saved plot (default: ``station_map``)
        """
        super().__init__(plot_folder, plot_name)
        self._ax = None  # map axes, created in _plot
        self._gl = None  # gridliner object, created in _plot
        self._plot(generators)
        self._save(bbox_inches="tight")

    def _draw_background(self):
        """Add land, coastline, lakes, ocean, rivers and country borders (all at 50m scale) to the map axes."""
        import cartopy.feature as cfeature

        ax = self._ax
        scale = "50m"
        ax.add_feature(cfeature.LAND.with_scale(scale))
        ax.natural_earth_shp(resolution='50m')
        ax.add_feature(cfeature.COASTLINE.with_scale(scale), edgecolor='black')
        for plain_feature in (cfeature.LAKES, cfeature.OCEAN, cfeature.RIVERS):
            ax.add_feature(plain_feature.with_scale(scale))
        ax.add_feature(cfeature.BORDERS.with_scale(scale), facecolor='none', edgecolor='black')

    def _plot_stations(self, generators):
        """
        Draw a marker for every station of every data set and build the legend.

        Nothing is drawn if ``generators`` is None. The legend is only added if at least one data set was processed.

        :param generators: iterable of data collections or ``(data collection, plot options)`` tuples
        """
        import cartopy.crs as ccrs

        if generators is None:
            return

        fallback_colors = self.get_dataset_colors()
        handles = []
        for entry in generators:
            collection, opts = self._get_collection_and_opts(entry)
            name = collection.name or "unknown"
            marker = opts.get("marker", "s")
            ms = opts.get("ms", 6)
            mec = opts.get("mec", "k")
            mfc = opts.get("mfc", fallback_colors.get(name, "b"))
            # the legend uses the adjusted marker, the map itself uses the marker as given
            legend_marker = self._adjust_marker(marker)
            legend_label = f"{name} ({len(collection)})"
            handles.append(mlines.Line2D([], [], mfc=mfc, mec=mec, marker=legend_marker, ms=ms, linestyle='None',
                                         label=legend_label))
            for station in collection:
                position = station.get_coordinates()
                lon, lat = position["lon"], position["lat"]
                self._ax.plot(lon, lat, mfc=mfc, mec=mec, marker=marker, ms=ms, transform=ccrs.PlateCarree())
        if handles:
            self._ax.legend(handles=handles, loc='best')

    @staticmethod
    def _adjust_marker(marker):
        """Translate the integer marker codes 4 to 11 into the string markers ``<``, ``>``, ``^``, ``v``."""
        replacements = {4: "<", 5: ">", 6: "^", 7: "v", 8: "<", 9: ">", 10: "^", 11: "v"}
        use_replacement = isinstance(marker, int) and marker in replacements
        return replacements[marker] if use_replacement else marker

    @staticmethod
    def _get_collection_and_opts(element):
        """
        Split a generators entry into data collection and plot options.

        A non-tuple entry and a one-element tuple both get an empty options dict. Any other tuple is returned as is
        and is unpacked into two values by the caller.
        """
        if not isinstance(element, tuple):
            return element, {}
        if len(element) == 1:
            return element[0], {}
        return element

    def _plot(self, generators: List):
        """
        Create the station map plot.

        Creates the figure with a PlateCarree map axes and grid lines, then draws background and stations and finally
        widens the map extent.

        :param generators: iterable of data collections or ``(data collection, plot options)`` tuples
        """
        import cartopy.crs as ccrs
        from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

        figure = plt.figure(figsize=(10, 5))
        self._ax = figure.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
        gridliner = self._ax.gridlines(xlocs=range(0, 21, 5), ylocs=range(44, 59, 2), draw_labels=True)
        self._gl = gridliner
        gridliner.xformatter = LONGITUDE_FORMATTER
        gridliner.yformatter = LATITUDE_FORMATTER
        self._draw_background()
        self._plot_stations(generators)
        self._adjust_extent()
        plt.tight_layout()

    def _adjust_extent(self):
        """
        Enlarge the current map extent by a margin on all four sides.

        The margin is ``5 / delta`` for the smaller of the two extent deltas (longitude, latitude), capped at 5.
        """
        import cartopy.crs as ccrs

        reference = 5
        extent = self._ax.get_extent(crs=ccrs.PlateCarree())
        delta_x = extent[1] - extent[0]
        delta_y = extent[3] - extent[2]
        margin = min(max(abs(reference / delta_x), abs(reference / delta_y)), 5)
        widened = extent + np.array([-1, 1, -1, 1]) * margin
        self._ax.set_extent(widened, crs=ccrs.PlateCarree())
@TimeTrackingWrapper
class PlotAvailability(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plot similar to Gantt plot.

    Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal
    resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a
    colored bar or a blank space.

    You can set different colors to highlight subsets for example by providing different generators for the same index
    using different keys in the input dictionary.

    Note: each bar is surrounded by a small white box to highlight gaps in between. This can result in too long gaps
    in display, if a gap is only very short. Also this appears on a (fluent) transition from one to another subset.

    Calling this class will create three versions of the availability plot.

    1) Data availability for each element
    2) Data availability as summary over all elements (is there at least a single element for each time step)
    3) Combination of single and overall availability

    .. image:: ../../../../../_source/_plots/data_availability.png
       :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_summary.png
       :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_combined.png
       :width: 400
    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".", sampling="daily",
                 summary_name="data availability", time_dimension="datetime", window_dimension="window"):
        """
        Prepare the availability data and create all three plots.

        :param generators: mapping of subset name to the data collection of this subset
        :param plot_folder: path to save the plots (default: current directory)
        :param sampling: sampling name, translated by the parent's ``_get_sampling`` into a resample frequency
        :param summary_name: row label used for the summary bar
        :param time_dimension: name of the time dimension of the label data
        :param window_dimension: name of the window dimension of the label data (index 1 is selected)
        """
        # create standard Gantt plot for all stations (currently in single pdf file with single page)
        super().__init__(plot_folder, "data_availability")
        self.time_dim = time_dimension
        self.window_dim = window_dimension
        self.sampling = self._get_sampling(sampling)
        # use a very thin bar edge for the 'h' frequency, otherwise leave the matplotlib default
        self.linewidth = 0.001 if self.sampling == 'h' else None
        plot_dict = self._prepare_data(generators)
        lgd = self._plot(plot_dict)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")

        # create summary Gantt plot (is data in at least one station available)
        self.plot_name += "_summary"
        plot_dict_summary = self._summarise_data(generators, summary_name)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")

        # combination of station and summary plot, last element is summary broken bar
        self.plot_name = "data_availability_combined"
        plot_dict_summary.update(plot_dict)
        lgd = self._plot(plot_dict_summary)
        self._save(bbox_extra_artists=(lgd,), bbox_inches="tight")

    def _get_availability(self, station):
        """Return a boolean array that is True for each resampled time step with a non-null label at window 1."""
        labels = station.get_Y(as_numpy=False).resample({self.time_dim: self.sampling}, skipna=True).mean()
        return labels.sel(**{self.window_dim: 1}).notnull()

    def _get_available_periods(self, avail):
        """
        Collapse a boolean availability array into consecutive runs of available data.

        :param avail: boolean array along the time dimension
        :return: list of ``(first time step of run, number of time steps in run)`` for all runs that are available
        """
        # a new group id starts whenever the availability flag differs from the previous time step
        group = (avail != avail.shift({self.time_dim: 1})).cumsum()
        plot_data = pd.DataFrame({"avail": avail.values, "group": group.values},
                                 index=avail.coords[self.time_dim].values)
        periods = []
        for _, run in plot_data.groupby("group"):
            # positional access: label based ``[0]`` on a datetime index is deprecated in pandas
            if run["avail"].iloc[0]:
                periods.append((run.index[0], run.shape[0]))
        return periods

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """
        Collect the available periods per station and subset.

        :param generators: mapping of subset name to data collection
        :return: dictionary ``{station name: {subset: [(start, length), ...]}}``
        """
        plt_dict = {}
        for subset, data_collection in generators.items():
            for station in data_collection:
                periods = self._get_available_periods(self._get_availability(station))
                plt_dict.setdefault(str(station), {})[subset] = periods
        return plt_dict

    def _summarise_data(self, generators: Dict[str, DataCollection], summary_name: str):
        """
        Collect the periods in which at least one station of a subset has data.

        :param generators: mapping of subset name to data collection
        :param summary_name: key (row label) under which the summary is stored
        :return: dictionary ``{summary_name: {subset: [(start, length), ...]}}``; subsets without any station are
            not included
        """
        plt_dict = {}
        for subset, data_collection in generators.items():
            all_data = None
            for station in data_collection:
                labels_bool = self._get_availability(station)
                if all_data is None:
                    all_data = labels_bool
                else:
                    tmp = all_data.combine_first(labels_bool)  # expand dims to merged datetime coords
                    all_data = np.logical_or(tmp, labels_bool).combine_first(
                        all_data)  # apply logical on merge and fill missing with all_data
            if all_data is None:
                # empty data collection: there is nothing to summarise for this subset
                continue
            plt_dict.setdefault(summary_name, {})[subset] = self._get_available_periods(all_data)
        return plt_dict

    def _plot(self, plt_dict):
        """
        Draw one row of broken bars per entry of plt_dict.

        :param plt_dict: dictionary ``{row label: {subset: [(start, length), ...]}}``
        :return: the created legend (used by the caller as extra artist when saving)
        """
        colors = self.get_dataset_colors()
        _used_colors = []
        height = 0.8  # should be <= 1
        yticklabels = []
        number_of_stations = len(plt_dict)
        fig, ax = plt.subplots(figsize=(10, number_of_stations / 3))
        # rows are sorted in reverse so that they read in ascending order from top to bottom
        for pos, (station, d) in enumerate(sorted(plt_dict.items(), reverse=True), start=1):
            for subset, color in colors.items():
                plt_data = d.get(subset)
                if plt_data is None:
                    continue
                if color not in _used_colors:  # this is required for a proper legend creation
                    _used_colors.append(color)
                ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth)
            yticklabels.append(station)

        ax.set_ylim([height, number_of_stations + 1])
        ax.set_yticks(np.arange(number_of_stations) + 1 + height / 2)  # tick in the vertical centre of each bar
        ax.set_yticklabels(yticklabels)
        handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors]
        lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles))
        return lgd
@TimeTrackingWrapper
class PlotAvailabilityHistogram(AbstractPlotClass):  # pragma: no cover
    """
    Create data availability plots as histogram.

    Each entry of each generator is checked for `notnull()` values along all the datetime axis (boolean).
    Calling this class creates two different types of histograms where each generator

    1) data_availability_histogram: datetime (xaxis) vs. number of stations with available data (yaxis)
    2) data_availability_histogram_cumulative: number of samples (xaxis) vs. number of stations having at least number
       of samples (yaxis)

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist.png
       :width: 400

    .. image:: ../../../../../_source/_plots/data_availability_histogram_hist_cum.png
       :width: 400
    """

    def __init__(self, generators: Dict[str, DataCollection], plot_folder: str = ".",
                 subset_dim: str = 'DataSet', history_dim: str = 'window',
                 station_dim: str = 'Stations', ):
        """
        Prepare the availability data and create one plot per allowed plot type.

        :param generators: mapping of subset name to the data collection of this subset
        :param plot_folder: path to save the plots (default: current directory)
        :param subset_dim: name of the new dimension along which the subsets are concatenated
        :param history_dim: name of the window dimension of the input data (index 0 is selected)
        :param station_dim: name of the new dimension along which the stations are concatenated
        """
        super().__init__(plot_folder, "data_availability_histogram")
        self.subset_dim = subset_dim
        self.history_dim = history_dim
        self.station_dim = station_dim
        # the following three are read from the data handlers in _prepare_data
        self.freq = None
        self.temporal_dim = None
        self.target_dim = None
        self._prepare_data(generators)

        base_plot_name = self.plot_name
        for plt_type in self.allowed_plot_types:
            # temporarily extend the plot name by the plot type and restore it afterwards
            self.plot_name = base_plot_name + '_' + plt_type
            self._plot(plt_type=plt_type)
            self._save()
            self.plot_name = base_plot_name

    def _set_dims_from_datahandler(self, data_handler):
        """Read time dimension, target dimension and sampling frequency from the data handler's ``id_class``."""
        self.temporal_dim = data_handler.id_class.time_dim
        self.target_dim = data_handler.id_class.target_dim
        self.freq = self._get_sampling(data_handler.id_class.sampling)

    @property
    def allowed_plot_types(self):
        """Names of all plot types that are created by this class."""
        return ['hist', 'hist_cum']

    def _prepare_data(self, generators: Dict[str, DataCollection]):
        """
        Prepares data to be used by plot methods.

        Creates xarrays which are sums of valid data (boolean sums) across i) station_dim and ii) temporal_dim.
        Results are stored in ``avail_data_amount`` (sum over stations, reindexed to a gap-free time index),
        ``avail_data_cum_sum`` (sum over time) and ``dataset_time_interval`` (first and last time step per subset).

        :param generators: mapping of subset name to data collection
        """
        avail_data_time_sum = {}
        avail_data_station_sum = {}
        dataset_time_interval = {}
        for subset, generator in generators.items():
            avail_list = []
            for station in generator:
                self._set_dims_from_datahandler(data_handler=station)
                station_data_x = station.get_X(as_numpy=False)[0]
                station_data_x = station_data_x.loc[{self.history_dim: 0,  # select recent window frame
                                                     self.target_dim: station_data_x[self.target_dim].values[0]}]
                station_data_x = self._reduce_dims(station_data_x)
                avail_list.append(station_data_x.notnull())
            # NOTE(review): the second notnull() maps every existing entry (True and False alike) to True and only
            # entries that are missing after concatenation to False - presumably the inputs contain no NaN at this
            # point; confirm against the data handler before changing this.
            avail_data = xr.concat(avail_list, dim=self.station_dim).notnull()
            avail_data_time_sum[subset] = avail_data.sum(dim=self.station_dim)
            avail_data_station_sum[subset] = avail_data.sum(dim=self.temporal_dim)
            dataset_time_interval[subset] = self._get_first_and_last_indexelement_from_xarray(
                avail_data_time_sum[subset], dim_name=self.temporal_dim, return_type='as_dict'
            )
        avail_data_amount = xr.concat(list(avail_data_time_sum.values()),
                                      pd.Index(avail_data_time_sum.keys(), name=self.subset_dim))
        full_time_index = self._make_full_time_index(avail_data_amount.coords[self.temporal_dim].values, freq=self.freq)
        self.avail_data_cum_sum = xr.concat(list(avail_data_station_sum.values()),
                                            pd.Index(avail_data_station_sum.keys(), name=self.subset_dim))
        self.avail_data_amount = avail_data_amount.reindex({self.temporal_dim: full_time_index})
        self.dataset_time_interval = dataset_time_interval

    def _reduce_dims(self, dataset):
        """
        Select the first entry along every dimension that is neither the temporal nor the station dimension.

        The dataset is returned unchanged if it has at most two dimensions.
        """
        if len(dataset.dims) > 2:
            required = {self.temporal_dim, self.station_dim}
            unimportant = set(dataset.dims).difference(required)
            sel_dict = {un: dataset[un].values[0] for un in unimportant}
            dataset = dataset.loc[sel_dict]
        return dataset

    @staticmethod
    def _get_first_and_last_indexelement_from_xarray(xarray, dim_name, return_type='as_tuple'):
        """
        Return first and last coordinate value of a DataArray along dim_name.

        :param xarray: data array to read the coordinate from
        :param dim_name: name of the coordinate
        :param return_type: either ``'as_tuple'`` (returns ``(first, last)``) or ``'as_dict'`` (returns
            ``{'first': first, 'last': last}``)
        :raises TypeError: if xarray is not a xr.DataArray or return_type is none of the two options
        """
        if not isinstance(xarray, xr.DataArray):
            raise TypeError(f"xarray must be of type xr.DataArray, but is of type {type(xarray)}")
        first = xarray.coords[dim_name].values[0]
        last = xarray.coords[dim_name].values[-1]
        if return_type == 'as_tuple':
            return first, last
        if return_type == 'as_dict':
            return {'first': first, 'last': last}
        raise TypeError(f"return_type must be 'as_tuple' or 'as_dict', but is '{return_type}'")

    @staticmethod
    def _make_full_time_index(irregular_time_index, freq):
        """Create a regular date range with given frequency from first to last element of the given index."""
        full_time_index = pd.date_range(start=irregular_time_index[0], end=irregular_time_index[-1], freq=freq)
        return full_time_index

    def _plot(self, plt_type='hist', *args):
        """
        Dispatch to the plot method of the requested plot type.

        :param plt_type: either ``'hist'`` or ``'hist_cum'``
        :raises ValueError: if plt_type is none of the two options
        """
        if plt_type == 'hist':
            self._plot_hist()
        elif plt_type == 'hist_cum':
            self._plot_hist_cum()
        else:
            raise ValueError(f"plt_type must be 'hist' or 'hist_cum', but is '{plt_type}'")

    def _plot_hist(self, *args):
        """Plot the number of stations with available data over time as filled step line per subset."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        for subset, interval in self.dataset_time_interval.items():
            # restrict each subset to its own time interval
            plot_dataset = self.avail_data_amount.sel({self.subset_dim: subset,
                                                       self.temporal_dim: slice(interval['first'], interval['last'])})
            plot_dataset.plot.step(color=colors[subset], ax=axes, label=subset)
            plt.fill_between(plot_dataset.coords[self.temporal_dim].values, plot_dataset.values, color=colors[subset])

        lgd = fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                         facecolor='white', framealpha=1, edgecolor='black')
        for lgd_line in lgd.get_lines():
            lgd_line.set_linewidth(4.0)
        plt.gca().xaxis.set_major_locator(mdates.YearLocator())
        plt.title('')
        plt.ylabel('Number of samples')
        plt.tight_layout()

    def _plot_hist_cum(self, *args):
        """Plot a reversed cumulative histogram of the number of samples per station for each subset."""
        colors = self.get_dataset_colors()
        fig, axes = plt.subplots(figsize=(10, 3))
        n_bins = int(self.avail_data_cum_sum.max().values)
        bins = np.arange(0, n_bins + 1)
        # draw subsets in descending order of their maximum sample count
        max_per_subset = self.avail_data_cum_sum.max(dim=self.station_dim)
        descending_subsets = max_per_subset.sortby(max_per_subset, ascending=False).coords[self.subset_dim].values

        for subset in descending_subsets:
            self.avail_data_cum_sum.sel({self.subset_dim: subset}).plot.hist(ax=axes,
                                                                             bins=bins,
                                                                             label=subset,
                                                                             cumulative=-1,
                                                                             color=colors[subset],
                                                                             )

        fig.legend(loc="upper right", ncol=len(self.dataset_time_interval),
                   facecolor='white', framealpha=1, edgecolor='black')
        plt.title('')
        plt.ylabel('Number of stations')
        plt.xlabel('Number of samples')
        plt.xlim((bins[0], bins[-1]))
        plt.tight_layout()
class PlotPeriodogram(AbstractPlotClass):  # pragma: no cover
    """
    Create Lomb-Scargle periodogram in raw input and target data. The Lomb-Scargle version can deal with missing values.

    This plot routine is creating the following plots:

    * "raw": data is not aggregated, 1 graph per variable
    * "": single data lines are aggregated, 1 graph per variable
    * "total": data is aggregated on all variables, single graph

    If data consists of different sampling rates, a separate plot is created for each sampling.

    .. image:: ../../../../../_source/_plots/periodogram.png
       :width: 400

    .. note::
        This plot is not included in the default plot list. To use this plot, add "PlotPeriodogram" to the `plot_list`.

    .. warning::
        This plot is highly sensitive to the data handler structure. Therefore, it is highly likely that this method is
        not compatible with any custom data handler. Proven data handlers are `DefaultDataHandler`,
        `DataHandlerMixedSampling`, `DataHandlerMixedSamplingWithFilter`. To work properly, the data handler must have
        the attribute `.id_class._data`.
    """

    def __init__(self, generator: DataCollection, plot_folder: str = ".", plot_name="periodogram",
                 variables_dim="variables", time_dim="datetime", sampling="daily", use_multiprocessing=False):
        """
        Calculate the periodograms and create all plots, once for input and once for target data.

        :param generator: collection of data handlers; it is indexed with ``[0]`` and iterated element-wise
        :param plot_folder: path to save the plots (default: current directory)
        :param plot_name: prefix of all created file names
        :param variables_dim: name of the variables dimension
        :param time_dim: name of the time dimension
        :param sampling: sampling of the data, either a single value used for input and target or a tuple of
            ``(input sampling, target sampling)``
        :param use_multiprocessing: calculate periodograms in parallel if more than one cpu is available
        """
        super().__init__(plot_folder, plot_name)
        self.variables_dim = variables_dim
        self.time_dim = time_dim

        samplings = sampling if isinstance(sampling, tuple) else (sampling, sampling)
        for pos, s in enumerate(samplings):
            self._sampling = s
            self._add_text = {0: "input", 1: "target"}[pos]
            multiple, label_names = self._has_filter_dimension(generator[0], pos)
            self._prepare_pgram(generator, pos, multiple, use_multiprocessing=use_multiprocessing)
            self._plot(raw=True)
            self._plot(raw=False)
            self._plot_total(raw=True)
            self._plot_total(raw=False)
            if multiple > 1:
                self._plot_difference(label_names)

    @staticmethod
    def _has_filter_dimension(g, pos):
        """
        Check if the data has a filter dimension that is not present in the raw data.

        :param g: single data handler
        :param pos: 0 to inspect the input data, 1 to inspect the target data
        :return: tuple of number of filter entries and their names, or ``(1, [])`` if no filter is applied
        """
        # check if coords raw data differs from input / target data
        check_data = g.id_class
        if "filter" not in [check_data.input_data, check_data.target_data][pos].coords.dims:
            return 1, []
        raw_dims = set(check_data._data[0].coords.dims)
        if len(raw_dims.symmetric_difference(check_data.input_data.coords.dims)) > 0:
            filter_coord = g.id_class.input_data.coords["filter"]
            return filter_coord.shape[0], filter_coord.values.tolist()
        return 1, []

    @staticmethod
    def _get_number_of_workers(n_tasks):
        """Return the pool size: number of physical cpus, limited by the number of tasks and 16, but at least 1."""
        physical_cpus = psutil.cpu_count(logical=False) or 1  # psutil returns None if the count is undetermined
        return max(1, min(physical_cpus, n_tasks, 16))

    @staticmethod
    def _rolling_window(data):
        """Return a centred rolling window of size 5 along the rows (frequency axis) of data."""
        return pd.DataFrame(data).rolling(5, center=True)

    def _open_pdf(self, file_name):
        """Open a multi-page pdf with given file name inside the plot folder."""
        # explicit import: ``import matplotlib`` alone does not guarantee that this submodule is loaded
        from matplotlib.backends.backend_pdf import PdfPages
        return PdfPages(os.path.join(os.path.abspath(self.plot_folder), file_name))

    @TimeTrackingWrapper
    def _prepare_pgram(self, generator, pos, multiple=1, use_multiprocessing=False):
        """
        Create periodogram data.

        All periodograms are interpolated on the common frequency axis ``f_index``. For each run (the unfiltered
        data, followed by one run per filter entry if ``multiple`` > 1) one dictionary per variable is appended to
        ``plot_data`` (list of single periodograms), ``plot_data_mean`` (frequency axis, mean periodogram) and
        ``plot_data_raw`` (frequency axis, 2D array of all periodograms).

        :param generator: collection of data handlers
        :param pos: 0 for input data, 1 for target data
        :param multiple: number of filter entries (1 if no filter is used)
        :param use_multiprocessing: calculate periodograms in parallel
        """
        self.raw_data = []
        self.plot_data = []
        self.plot_data_raw = []
        self.plot_data_mean = []
        # upper frequency limit is 1 per day for daily data and 24 per day otherwise
        self.f_index = np.logspace(-3, 0 if self._sampling == "daily" else np.log10(24), 1000)
        number_of_runs = multiple if multiple == 1 else multiple + 1
        for m in range(number_of_runs):
            plot_data_single = dict()
            plot_data_raw_single = dict()
            plot_data_mean_single = dict()
            raw_data_single = self._prepare_pgram_parallel_gen(generator, m, pos, use_multiprocessing)
            for var, all_data in raw_data_single.items():
                pgram_com = []
                pgram_mean = 0
                pgram_mean_raw = np.zeros((len(self.f_index), len(all_data)))
                for i, (f, pgram) in enumerate(all_data):
                    d = np.interp(self.f_index, f, pgram)
                    pgram_com.append(d)
                    pgram_mean += d
                    pgram_mean_raw[:, i] = d
                pgram_mean /= len(all_data)
                plot_data_single[var] = pgram_com
                plot_data_mean_single[var] = (self.f_index, pgram_mean)
                plot_data_raw_single[var] = (self.f_index, pgram_mean_raw)
            self.plot_data.append(plot_data_single)
            self.plot_data_mean.append(plot_data_mean_single)
            self.plot_data_raw.append(plot_data_raw_single)

    def _prepare_pgram_parallel_var(self, generator, m, pos, use_multiprocessing):
        """
        Implementation of data preprocessing using parallel variables element processing.

        :return: dictionary ``{variable name: [(frequency, power), ...]}`` with one entry per data handler
        """
        raw_data_single = dict()
        run_parallel = multiprocessing.cpu_count() > 1 and use_multiprocessing
        for g in generator:
            if m == 0:
                d = g.id_class._data
            else:
                gd = g.id_class
                filter_sel = {"filter": gd.input_data.coords["filter"][m - 1]}
                d = (gd.input_data.sel(filter_sel), gd.target_data)
            d = d[pos] if isinstance(d, tuple) else d
            variables = d[self.variables_dim].values
            if run_parallel:  # parallel solution
                # the context manager releases the worker processes even if a task fails
                with multiprocessing.Pool(self._get_number_of_workers(len(variables))) as pool:
                    output = [
                        pool.apply_async(f_proc,
                                         args=(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim)))
                        for var in variables]
                    res = [p.get() for p in output]
            else:  # serial solution
                res = [f_proc(var, d.loc[{self.variables_dim: var}].squeeze().dropna(self.time_dim))
                       for var in variables]
            for (var_str, f, pgram) in res:
                raw_data_single.setdefault(var_str, []).append((f, pgram))
        return raw_data_single

    def _prepare_pgram_parallel_gen(self, generator, m, pos, use_multiprocessing):
        """
        Implementation of data preprocessing using parallel generator element processing.

        :return: dictionary ``{variable name: [(frequency, power), ...]}`` with one entry per data handler
        """
        raw_data_single = dict()
        if multiprocessing.cpu_count() > 1 and use_multiprocessing:  # parallel solution
            # the context manager releases the worker processes even if a task fails
            with multiprocessing.Pool(self._get_number_of_workers(len(generator))) as pool:
                output = [
                    pool.apply_async(f_proc_2, args=(g, m, pos, self.variables_dim, self.time_dim))
                    for g in generator]
                res = [p.get() for p in output]
        else:  # serial solution
            res = [f_proc_2(g, m, pos, self.variables_dim, self.time_dim) for g in generator]
        for res_dict in res:
            for k, v in res_dict.items():
                if k not in raw_data_single:
                    raw_data_single[k] = v
                else:
                    raw_data_single[k] = raw_data_single[k] + v
        return raw_data_single

    @staticmethod
    def _add_annotation_line(ax, pos, div, lims, unit):
        """
        Draw labelled vertical lines at the frequencies ``p / div`` for all p in pos.

        :param ax: current ax object
        :param pos: single value or list of values, number of cycles per unit
        :param div: length of the unit in days
        :param lims: lower and upper y limit of the line
        :param unit: unit name used in the label
        """
        for p in to_list(pos):  # per year
            ax.vlines(p / div, *lims, "black")
            ax.text(p / div, lims[0], r"%s$%s^{-1}$" % (p, unit), rotation="vertical", rotation_mode="anchor")

    def _format_figure(self, ax, var_name="total"):
        """
        Set log scale on both axis, add labels and annotation lines, and set title.

        :param ax: current ax object
        :param var_name: name of variable that will be included in the title
        """
        ax.set_yscale('log')
        ax.set_xscale('log')
        ax.set_ylabel("power", fontsize='x-large')
        ax.set_xlabel("frequency $[day^{-1}$]", fontsize='x-large')
        lims = ax.get_ylim()
        self._add_annotation_line(ax, [1, 2, 3], 365.25, lims, "yr")  # per year
        self._add_annotation_line(ax, 1, 365.25 / 12, lims, "m")  # per month
        self._add_annotation_line(ax, 1, 7, lims, "w")  # per week
        self._add_annotation_line(ax, [1, 0.5], 1, lims, "d")  # per day
        if self._sampling == "hourly":
            self._add_annotation_line(ax, 2, 1, lims, "d")  # per day
            self._add_annotation_line(ax, [1, 0.5], 1 / 24., lims, "h")  # per hour
        title = f"Periodogram ({var_name})"
        ax.set_title(title)

    def _plot(self, raw=True):
        """
        Create one page per variable using the periodograms of the first (unfiltered) run.

        :param raw: if True plot each single periodogram and their mean, otherwise plot the smoothed mean with the
            smoothed min / max range as shaded area
        """
        file_name = f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}.pdf"
        plot_data = self.plot_data[0]
        plot_data_mean = self.plot_data_mean[0]
        try:
            with self._open_pdf(file_name) as pdf_pages:
                for var in plot_data.keys():
                    fig, ax = plt.subplots()
                    if raw is True:
                        for pgram in plot_data[var]:
                            ax.plot(self.f_index, pgram, "lightblue")
                        ax.plot(*plot_data_mean[var], "blue")
                    else:
                        ma = self._rolling_window(np.vstack(plot_data[var]).T)
                        mean = ma.mean().mean(axis=1).values.flatten()
                        upper = ma.max().mean(axis=1).values.flatten()
                        lower = ma.min().mean(axis=1).values.flatten()
                        ax.plot(self.f_index, mean, "blue")
                        ax.fill_between(self.f_index, lower, upper, color="lightblue")
                    self._format_figure(ax, var)
                    pdf_pages.savefig()
        finally:
            # close all open figures / plots
            plt.close('all')

    def _plot_total(self, raw=True):
        """
        Create a single page with the periodograms of all variables of the first (unfiltered) run combined.

        :param raw: if True plot each single periodogram and their mean, otherwise plot the smoothed mean with the
            smoothed min / max range as shaded area
        """
        file_name = f"{self.plot_name}{'_raw' if raw else ''}_{self._sampling}_{self._add_text}_total.pdf"
        plot_data_raw = self.plot_data_raw[0]
        try:
            with self._open_pdf(file_name) as pdf_pages:
                fig, ax = plt.subplots()
                res = None
                for var in plot_data_raw.keys():
                    d_var = plot_data_raw[var][1]
                    res = d_var if res is None else np.concatenate((res, d_var), axis=-1)
                if raw is True:
                    for i in range(res.shape[1]):
                        ax.plot(self.f_index, res[:, i], "lightblue")
                    ax.plot(self.f_index, res.mean(axis=1), "blue")
                else:
                    ma = self._rolling_window(np.vstack(res))
                    mean = ma.mean().mean(axis=1).values.flatten()
                    upper = ma.max().mean(axis=1).values.flatten()
                    lower = ma.min().mean(axis=1).values.flatten()
                    ax.plot(self.f_index, mean, "blue")
                    ax.fill_between(self.f_index, lower, upper, color="lightblue")
                self._format_figure(ax, "total")
                pdf_pages.savefig()
        finally:
            # close all open figures / plots
            plt.close('all')

    def _plot_difference(self, label_names):
        """
        Create one page per variable that compares the smoothed mean periodogram of all runs.

        The min / max range is only shaded for the first (unfiltered) run.

        :param label_names: names of the filter entries, used as legend labels (the first run is labelled "orig")
        """
        plot_name = f"{self.plot_name}_{self._sampling}_{self._add_text}_filter.pdf"
        logging.info(f"... plotting {plot_name}")
        colors = ["blue", "red", "green", "orange", "purple", "black", "grey"]
        label_names = ["orig"] + label_names
        max_iter = len(self.plot_data)
        var_keys = self.plot_data[0].keys()
        try:
            with self._open_pdf(plot_name) as pdf_pages:
                for var in var_keys:
                    fig, ax = plt.subplots()
                    for i in reversed(range(max_iter)):
                        plot_data = self.plot_data[i]
                        c = colors[i % len(colors)]  # reuse colors if there are more runs than colors
                        ma = self._rolling_window(np.vstack(plot_data[var]).T)
                        mean = ma.mean().mean(axis=1).values.flatten()
                        ax.plot(self.f_index, mean, c, label=label_names[i])
                        if i < 1:
                            upper = ma.max().mean(axis=1).values.flatten()
                            lower = ma.min().mean(axis=1).values.flatten()
                            ax.fill_between(self.f_index, lower, upper, color="light" + c, alpha=0.5, label=None)
                    self._format_figure(ax, var)
                    ax.legend(loc="upper center", ncol=max_iter)
                    pdf_pages.savefig()
        finally:
            # close all open figures / plots
            plt.close('all')
def f_proc(var, d_var):
    """
    Calculate the Lomb-Scargle periodogram of a single variable.

    :param var: name of the variable, returned as string
    :param d_var: data of this variable; must provide a ``datetime`` coordinate and ``values``
    :return: tuple of variable name, frequencies and power as returned by ``LombScargle(...).autopower()``
    """
    time_axis = d_var.datetime
    # time elapsed since the first time step, truncated to hours and expressed in days
    elapsed = (time_axis - time_axis[0]).astype("timedelta64[h]").values / np.timedelta64(1, "D")
    periodogram = LombScargle(elapsed, d_var.values.flatten(), nterms=1)
    frequency, power = periodogram.autopower()
    return str(var), frequency, power
def f_proc_2(g, m, pos, variables_dim, time_dim):
    """
    Calculate the periodograms of all variables of a single data handler.

    :param g: data handler providing ``id_class`` with ``_data``, ``input_data`` and ``target_data``
    :param m: 0 to use the data stored in ``_data``, otherwise the input data is restricted to filter entry
        ``m - 1`` and paired with the target data
    :param pos: index into the data tuple (the caller uses 0 for input and 1 for target data)
    :param variables_dim: name of the variables dimension
    :param time_dim: name of the time dimension along which missing values are dropped
    :return: dictionary ``{variable name: [(frequency, power)]}``
    """
    handler = g.id_class
    if m == 0:
        data = handler._data
    else:
        chosen_filter = {"filter": handler.input_data.coords["filter"][m - 1]}
        data = (handler.input_data.sel(chosen_filter), handler.target_data)
    if isinstance(data, tuple):
        data = data[pos]

    periodograms = {}
    for var in data[variables_dim].values:
        series = data.loc[{variables_dim: var}].squeeze().dropna(time_dim)
        name, frequency, power = f_proc(var, series)
        periodograms[name] = [(frequency, power)]
    return periodograms
Loading