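# NOTE: the following lines are residue of the web page (GitLab file view) this file was copied from:
# Select Git revision
# _Plot_cell_rho.py
# Forked from
# JuPedSim / JPSreport
# Source project has a limited visibility.
# postprocessing_plotting.py 32.14 KiB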
# Module metadata: authors and date of this plotting module.
__author__ = "Lukas Leufen, Felix Kleinert"
__date__ = '2019-12-17'
import logging
import math
import os
import warnings
from typing import Dict, List, Tuple
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from matplotlib.backends.backend_pdf import PdfPages
from src import helpers
from src.helpers import TimeTracking
from src.run_modules.run_environment import RunEnvironment
# Only let matplotlib log WARNING and above, so its DEBUG/INFO chatter does not drown this module's debug output.
logging.getLogger('matplotlib').setLevel(logging.WARNING)
class PlotMonthlySummary(RunEnvironment):
    """
    Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot. The plot is saved
    in plot_folder with name monthly_summary_box_plot.pdf and 500dpi resolution.
    """

    def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
                 plot_folder: str = "."):
        """
        Set attributes and create the plot.

        :param stations: all stations to plot
        :param data_path: path, where the data is located
        :param name: full name of the local files with a % as placeholder for the station name
        :param target_var: display name of the target variable on plot's axis
        :param window_lead_time: lead time to plot, if window_lead_time is higher than the available lead time or not
            given the maximum lead time from data is used. (default None -> use maximum lead time from data).
        :param plot_folder: path to save the plot (default: current directory)
        """
        super().__init__()
        self._data_path = data_path
        self._data_name = name
        self._data = self._prepare_data(stations)
        self._window_lead_time = self._get_window_lead_time(window_lead_time)
        self._plot(target_var, plot_folder)

    def _prepare_data(self, stations: List) -> xr.DataArray:
        """
        Pre-process data required to plot.

        For each station, load the locally saved predictions and extract the CNN prediction and the observation. The
        observation (taken at ahead=1) is prepended along "ahead" under the label "obs", CNN lead times are relabelled
        as "<n>d", and the time index is replaced by the month number (1-12). No aggregation happens, values are only
        grouped by month. Negative values are clipped to 0.

        :param stations: all stations to plot
        :return: The entire data set, flagged with the corresponding month (None if stations is empty).
        """
        station_data = []
        for station in stations:
            logging.debug(f"... preprocess station {station}")
            file_name = os.path.join(self._data_path, self._data_name % station)
            # read everything into memory and release the file handle immediately (it was never closed before)
            with xr.open_dataarray(file_name) as raw:
                data = raw.load()
            data_cnn = data.sel(type="CNN").squeeze()
            if len(data_cnn.shape) > 1:
                # relabel lead times as "<n>d"; assign_coords instead of writing to .values of a dimension
                # coordinate, which newer xarray versions reject
                data_cnn = data_cnn.assign_coords(ahead=[f"{days}d" for days in data_cnn.coords["ahead"].values])
            data_obs = data.sel(type="obs", ahead=1).squeeze()
            data_obs.coords["ahead"] = "obs"
            data_concat = xr.concat([data_obs, data_cnn], dim="ahead")
            data_concat = data_concat.drop("type")
            # replace time stamps by their month number (1-12), used as grouping key on the x-axis
            months = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1
            data_concat = data_concat.assign_coords(index=months)
            data_concat = data_concat.clip(min=0)
            station_data.append(data_concat)
        # concatenate once at the end instead of re-copying the growing array for every station
        return xr.concat(station_data, 'index') if station_data else None

    def _get_window_lead_time(self, window_lead_time: int):
        """
        Extract the lead time from data and arguments.

        If window_lead_time is not given, use the number of entries along the ahead dimension of the data. If given,
        it is capped at that number.

        :param window_lead_time: lead time from arguments to validate
        :return: validated lead time, comes either from given argument or from data itself
        """
        # NOTE(review): self._data.ahead also contains the "obs" entry added in _prepare_data, so ahead_steps is one
        # more than the number of forecast lead times - confirm this is intended.
        ahead_steps = len(self._data.ahead)
        if window_lead_time is None:
            window_lead_time = ahead_steps
        return min(ahead_steps, window_lead_time)

    def _plot(self, target_var: str, plot_folder: str):
        """
        Create a monthly grouped box plot over all stations with separate boxes for each lead time step.

        Observations are drawn in green, lead times in shades of the "Blues_d" palette.

        :param target_var: display name of the target variable on plot's axis
        :param plot_folder: path to save the plot
        """
        data = self._data.to_dataset(name='values').to_dask_dataframe()
        logging.debug("... start plotting")
        color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d",
                                                                                self._window_lead_time).as_hex()
        ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1., palette=color_palette,
                         flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
                         meanprops={'markersize': 1, 'markeredgecolor': 'k'})
        ax.set(xlabel='month', ylabel=f'{target_var}')
        plt.tight_layout()
        self._save(plot_folder)

    @staticmethod
    def _save(plot_folder):
        """
        Store the current figure as monthly_summary_box_plot.pdf (500dpi) in plot_folder and close all figures.

        :param plot_folder: path to save the plot
        """
        plot_name = os.path.join(os.path.abspath(plot_folder), 'monthly_summary_box_plot.pdf')
        logging.debug(f"... save plot to {plot_name}")
        plt.savefig(plot_name, dpi=500)
        plt.close('all')
class PlotStationMap(RunEnvironment):
    """
    Geographical overview of all used stations, each drawn as a square marker.

    Every entry of the ``generators`` dictionary is drawn in its own color: the key is the marker face color and the
    value is the generator holding the stations of that data set. The background shows coastlines, lakes, ocean,
    rivers and country borders only (no topography yet). The figure is stored as station_map.pdf in plot_folder.
    """

    def __init__(self, generators: Dict, plot_folder: str = "."):
        """
        Set attributes and create the plot.

        :param generators: mapping of plot color (key) to the generator containing all stations of a data set (value)
        :param plot_folder: path to save the plot (default: current directory)
        """
        super().__init__()
        self._ax = None
        self._plot(generators, plot_folder)

    def _draw_background(self):
        """Add coastlines, lakes, ocean, rivers and country borders (all at 50m scale) to the map axes."""
        scale = "50m"
        map_ax = self._ax
        map_ax.add_feature(cfeature.COASTLINE.with_scale(scale), edgecolor='black')
        for plain_feature in (cfeature.LAKES, cfeature.OCEAN, cfeature.RIVERS):
            map_ax.add_feature(plain_feature.with_scale(scale))
        map_ax.add_feature(cfeature.BORDERS.with_scale(scale), facecolor='none', edgecolor='black')

    def _plot_stations(self, generators):
        """
        Draw a black-edged square at each station position, filled with the color of its data set.

        :param generators: mapping of plot color (key) to the generator containing all stations of a data set (value);
            nothing is drawn if None
        """
        if generators is None:
            return
        for color, gen in generators.items():
            # gen is only enumerated to obtain positional station indices
            for idx, _ in enumerate(gen):
                coords = gen.get_data_generator(idx).meta.loc[['station_lon', 'station_lat']]
                lon = float(coords.loc['station_lon'].values)
                lat = float(coords.loc['station_lat'].values)
                self._ax.plot(lon, lat, mfc=color, mec='k', marker='s', markersize=6, transform=ccrs.PlateCarree())

    def _plot(self, generators: Dict, plot_folder: str):
        """
        Build the map figure (PlateCarree projection), draw background and stations, then save it.

        :param generators: mapping of plot color (key) to the generator containing all stations of a data set (value)
        :param plot_folder: path to save the plot
        """
        figure = plt.figure(figsize=(10, 5))
        self._ax = figure.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
        # fixed map window: longitude 0..20, latitude 42..58
        self._ax.set_extent([0, 20, 42, 58], crs=ccrs.PlateCarree())
        self._draw_background()
        self._plot_stations(generators)
        self._save(plot_folder)

    @staticmethod
    def _save(plot_folder):
        """
        Store the current figure as station_map.pdf (500dpi) in plot_folder and close all figures.

        :param plot_folder: path to save the plot
        """
        target = os.path.join(os.path.abspath(plot_folder), 'station_map.pdf')
        logging.debug(f"... save plot to {target}")
        plt.savefig(target, dpi=500)
        plt.close('all')
def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_window: int = 3, ref_name: str = 'obs',
                               pred_name: str = 'CNN', season: str = "", forecast_path: str = None,
                               plot_name_affix: str = "", units: str = "ppb"):
    """
    Plot the conditional quantiles of ref_name given binned pred_name values, one pdf page per lead time.

    This plot was originally taken from Murphy, Brown and Chen (1989):
    https://journals.ametsoc.org/doi/pdf/10.1175/1520-0434%281989%29004%3C0485%3ADVOTF%3E2.0.CO%3B2

    :param stations: stations to include in the plot (forecast data needs to be available already)
    :param plot_folder: path to save the plot (default: current directory)
    :param rolling_window: the rolling window mean will smooth the plot appearance (no smoothing in bin calculation,
        this is only a cosmetic step, default: 3)
    :param ref_name: name of the reference data series
    :param pred_name: name of the investigated data series
    :param season: season name to highlight if not empty
    :param forecast_path: directory holding the forecast files forecasts_<station>_test.nc (required)
    :param plot_name_affix: name to specify this plot (e.g. 'cali-ref', default: '')
    :param units: units of the forecasted values (default: ppb)
    :raises ValueError: if forecast_path is None
    """
    logging.debug("started plot_conditional_quantiles()")
    # keep the filters local to this call; plain warnings.filterwarnings would silence these messages process-wide
    with warnings.catch_warnings():
        # ignore warnings if nans appear in quantile grouping
        warnings.filterwarnings("ignore", message="All-NaN slice encountered")
        # ignore warnings if mean is calculated on nans
        warnings.filterwarnings("ignore", message="Mean of empty slice")
        # ignore warnings for y tick = 0 on log scale (instead of 0.00001 or similar)
        warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.")

        # check forecast path before any work is done
        if forecast_path is None:
            raise ValueError("Forecast path is not given but required.")
        time_tracker = TimeTracking()

        def load_data():
            """Load CNN, obs and OLS series of all stations into one array with dims (index, type, ahead, station)."""
            logging.debug("... load data")
            data_collector = []
            for station in stations:
                file = os.path.join(forecast_path, f"forecasts_{station}_test.nc")
                # read the selection into memory so the file handle can be closed right away
                with xr.open_dataarray(file) as data_tmp:
                    selection = data_tmp.loc[:, :, ['CNN', 'obs', 'OLS']].assign_coords(station=station)
                    data_collector.append(selection.load())
            return xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station')

        def segment_data(data, bins):
            """Flatten index/station into one position dim "z" and replace pred_name values by their bin label."""
            logging.debug("... segment data")
            # combine index and station to multi index
            data = data.stack(z=['index', 'station'])
            # replace multi index by simple position index (order is not relevant anymore)
            data.coords['z'] = range(len(data.coords['z']))
            # segment data of pred_name into bins (each value is labelled with the upper edge of its bin)
            data.loc[pred_name, ...] = data.loc[pred_name, ...].to_pandas().T.apply(pd.cut, bins=bins,
                                                                                   labels=bins[1:]).T.values
            return data

        def create_quantile_panel(data, q, bins):
            """For each pred_name bin, compute the quantiles q of the matching ref_name values per lead time."""
            logging.debug("... create quantile panel")
            # create empty xarray with dims: time steps ahead, quantiles, bin index (numbers create in previous step)
            quantile_panel = xr.DataArray(np.full([data.ahead.shape[0], len(q), bins[1:].shape[0]], np.nan),
                                          coords=[data.ahead, q, bins[1:]], dims=['ahead', 'quantiles', 'categories'])
            # ensure that the coordinates are in the right order
            quantile_panel = quantile_panel.transpose('ahead', 'quantiles', 'categories')
            # calculate for each bin of the pred_name data the quantiles of the ref_name data
            for bin_label in bins[1:]:
                mask = (data.loc[pred_name, ...] == bin_label)
                quantile_panel.loc[..., bin_label] = data.loc[ref_name, ...].where(mask).quantile(q, dim=['z']).T
            return quantile_panel

        def axis_labels(plot_type, data_unit="ppb"):
            """Return (xlabel, ylabel); swapped unless plot_type is "obs"."""
            names = (f"forecast concentration (in {data_unit})", f"observed concentration (in {data_unit})")
            return names if plot_type == "obs" else names[::-1]

        def add_affix(x):
            """Prefix a non-empty name part with "_", map an empty one to ""."""
            return f"_{x}" if len(x) > 0 else ""

        xlabel, ylabel = axis_labels(ref_name, units)
        opts = {"q": [.1, .25, .5, .75, .9], "linetype": [':', '-.', '--', '-.', ':'],
                "legend": ['.10th and .90th quantile', '.25th and .75th quantile', '.50th quantile', 'reference 1:1'],
                "xlabel": xlabel, "ylabel": ylabel}

        # set name and path of the plot
        base_name = "conditional_quantiles"
        plot_name = f"{base_name}{add_affix(season)}{add_affix(plot_name_affix)}_plot.pdf"
        plot_path = os.path.join(os.path.abspath(plot_folder), plot_name)

        # load data and set unit-width integer bins from 0 up to the rounded-up maximum of all data
        orig_data = load_data()
        bins = np.arange(0, math.ceil(orig_data.max().max()) + 1, 1).astype(int)
        segmented_data = segment_data(orig_data, bins)
        quantile_panel = create_quantile_panel(segmented_data, opts["q"], bins)

        logging.debug(f"... plot path is {plot_path}")
        y2_max = 0
        # the context manager closes the pdf file even if plotting raises
        with PdfPages(plot_path) as pdf_pages:
            # create plot for each time step ahead
            for iteration, d in enumerate(segmented_data.ahead):
                logging.debug(f"... plotting {d.values} time step(s) ahead")
                # plot smoothed lines with rolling mean
                smooth_data = quantile_panel.loc[d, ...].rolling(categories=rolling_window, center=True).mean()
                smooth_data = smooth_data.to_pandas().T
                ax = smooth_data.plot(style=opts["linetype"], color='black', legend=False)
                ax2 = ax.twinx()
                # add reference line
                ax.plot([0, bins.max()], [0, bins.max()], color='k', label='reference 1:1', linewidth=.8)
                # collect handles of the quantile lines and the reference line before adding the histogram
                handles, _ = ax.get_legend_handles_labels()
                # add histogram of the segmented data (pred_name)
                segmented_data.loc[pred_name, d, :].to_pandas().hist(bins=bins, ax=ax2, color='k', alpha=.3,
                                                                      grid=False, rwidth=1)
                # legend: first three quantile lines (.10, .25, .50) plus the reference line
                plt.legend(handles[:3] + [handles[-1]], opts["legend"], loc='upper left', fontsize='large')
                # adjust limits and set labels
                ax.set(xlim=(0, bins.max()), ylim=(0, bins.max()))
                ax.set_xlabel(opts["xlabel"], fontsize='x-large')
                ax.tick_params(axis='x', which='major', labelsize=15)
                ax.set_ylabel(opts["ylabel"], fontsize='x-large')
                ax.tick_params(axis='y', which='major', labelsize=15)
                ax2.yaxis.label.set_color('gray')
                ax2.tick_params(axis='y', colors='gray')
                ax2.yaxis.labelpad = -15
                ax2.set_yscale('log')
                # the histogram axis limit of the first lead time is reused for all following pages
                if iteration == 0:
                    y2_max = ax2.get_ylim()[1] + 100
                ax2.set(ylim=(0, y2_max * 10 ** 8), yticks=np.logspace(0, 4, 5))
                ax2.set_ylabel(' sample size', fontsize='x-large')
                ax2.tick_params(axis='y', which='major', labelsize=15)
                # set title and save current figure
                title = f"{d.values} time step(s) ahead{f' ({season})' if len(season) > 0 else ''}"
                plt.title(title)
                pdf_pages.savefig()
                # free this page's figure instead of keeping one open figure per lead time
                plt.close(ax.figure)
        # close all open figures / plots
        plt.close('all')
        logging.info(f"plot_conditional_quantiles() finished after {time_tracker}")
class PlotClimatologicalSkillScore(RunEnvironment):
    """
    Box plot (over all stations) of the climatological skill score after Murphy (1988).

    Each forecast time step ("ahead") gets its own box, so differences between prediction steps stand out. With
    score_only=True (default) only the resulting scores CASE I to IV are shown, otherwise every single term as well.
    The y-axis range follows the data. The figure is stored in plot_folder as
    skill_score_clim_{extra_name_tag}{model_setup}.pdf with 500dpi.
    """

    def __init__(self, data: Dict, plot_folder: str = ".", score_only: bool = True, extra_name_tag: str = "",
                 model_setup: str = ""):
        """
        Set attributes and create the plot.

        :param data: dictionary with station names as keys and 2D xarrays (axes ahead and terms) as values
        :param plot_folder: path to save the plot (default: current directory)
        :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True)
        :param extra_name_tag: additional tag that can be included in the plot name (default "")
        :param model_setup: architecture type to specify plot name (default "")
        """
        super().__init__()
        self._labels = None
        self._data = self._prepare_data(data, score_only)
        self._plot(plot_folder, score_only, extra_name_tag, model_setup)

    def _prepare_data(self, data: Dict, score_only: bool) -> pd.DataFrame:
        """
        Merge the per-station arrays into one long-format frame and derive the legend labels ("<ahead>d").

        :param data: dictionary with station names as keys and 2D xarrays as values
        :param score_only: if true keep only the CASE I to IV scores
        :return: pre-processed data set
        """
        combined = helpers.dict_to_xarray(data, "station")
        self._labels = [f"{step}d" for step in combined.coords["ahead"].values]
        if score_only:
            cases = ["CASE I", "CASE II", "CASE III", "CASE IV"]
            combined = combined.loc[:, cases, :]
        frame = combined.to_dataframe("data")
        return frame.reset_index(level=[0, 1, 2])

    def _label_add(self, score_only: bool):
        """
        Return the y-label prefix: "" when score_only is true, "terms and " otherwise.

        :param score_only: if false all terms are relevant, otherwise only CASE I to IV
        :return: additional label
        """
        if score_only:
            return ""
        return "terms and "

    def _plot(self, plot_folder, score_only, extra_name_tag, model_setup):
        """
        Draw the box plot (terms on x, one hue per lead time) and save it.

        :param plot_folder: path to save the plot
        :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (larger figure)
        :param extra_name_tag: additional tag that can be included in the plot name
        :param model_setup: architecture type to specify plot name
        """
        fig, ax = plt.subplots()
        if not score_only:
            fig.set_size_inches(11.7, 8.27)
        sns.boxplot(x="terms", y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d",
                    showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
        ax.axhline(y=0, color="grey", linewidth=.5)
        y_label = f"{self._label_add(score_only)}skill score"
        ax.set(ylabel=y_label, xlabel="", title="summary of all stations")
        legend_handles = ax.get_legend_handles_labels()[0]
        ax.legend(legend_handles, self._labels)
        plt.tight_layout()
        self._save(plot_folder, extra_name_tag, model_setup)

    @staticmethod
    def _save(plot_folder, extra_name_tag, model_setup):
        """
        Store the current figure as skill_score_clim_{extra_name_tag}{model_setup}.pdf (500dpi) and close all figures.

        :param plot_folder: path to save the plot
        :param extra_name_tag: additional tag that can be included in the plot name
        :param model_setup: architecture type to specify plot name
        """
        file_name = f"skill_score_clim_{extra_name_tag}{model_setup}.pdf"
        plot_name = os.path.join(plot_folder, file_name)
        logging.debug(f"... save plot to {plot_name}")
        plt.savefig(plot_name, dpi=500)
        plt.close('all')
class PlotCompetitiveSkillScore(RunEnvironment):
    """
    Create competitive skill score for the given model setup and the reference models ordinary least squared ("ols") and
    the persistence forecast ("persi") for all lead times ("ahead"). The plot is saved under plot_folder with the name
    skill_score_competitive_{model_setup}.pdf and resolution of 500dpi.
    """

    def __init__(self, data: pd.DataFrame, plot_folder=".", model_setup="CNN"):
        """
        Set attributes and create the plot.

        :param data: data frame(s) with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the
            pre-calculated comparisons for cnn, persistence and ols. NOTE(review): _prepare_data passes this to
            pd.concat and names the outer level "stations", so it is presumably a mapping station -> data frame;
            confirm against the caller.
        :param plot_folder: path to save the plot (default: current directory)
        :param model_setup: architecture type (default "CNN")
        """
        super().__init__()
        self._labels = None
        self._data = self._prepare_data(data)
        self._plot(plot_folder, model_setup)

    def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Reformat given data to long format with columns "comparison", "ahead" and "data", and create plot labels.

        :param data: data frame(s) with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the
            pre-calculated comparisons for cnn, persistence and ols.
        :return: processed data
        """
        data = pd.concat(data, axis=0)
        data = xr.DataArray(data, dims=["stations", "ahead"]).unstack("stations")
        data = data.rename({"stations_level_0": "stations", "stations_level_1": "comparison"})
        data = data.to_dataframe("data").unstack(level=1).swaplevel()
        # station labels in actual column order (levels[1] only lines up if the level is sorted and fully used)
        data.columns = data.columns.get_level_values(1)
        self._labels = [str(i) + "d" for i in data.index.levels[1].values]
        # the station level is dropped; resulting columns: comparison, ahead, data
        return data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data")

    def _plot(self, plot_folder, model_setup):
        """
        Main plot function to plot skill scores of the comparisons cnn-persi, ols-persi and cnn-ols.

        :param plot_folder: path to save the plot
        :param model_setup: architecture type to specify plot name
        """
        fig, ax = plt.subplots()
        sns.boxplot(x="comparison", y="data", hue="ahead", data=self._data, whis=1., ax=ax, palette="Blues_d",
                    showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."},
                    order=["cnn-persi", "ols-persi", "cnn-ols"])
        ax.axhline(y=0, color="grey", linewidth=.5)
        ax.set(ylabel="skill score", xlabel="competing models", title="summary of all stations", ylim=self._ylim())
        handles, _ = ax.get_legend_handles_labels()
        ax.legend(handles, self._labels)
        plt.tight_layout()
        self._save(plot_folder, model_setup)

    def _ylim(self) -> Tuple[float, float]:
        """
        Calculate y-axis limits from the "data" column.

        Lower limit is the minimum of 0 and (helpers.float_round(data minimum, 2) - 0.1); upper limit is
        helpers.float_round(data maximum, 2) + 0.1.

        :return: tuple of (lower, upper) y-axis limit
        """
        # select the score column by name; positional Series access like `.min()[2]` is deprecated in pandas
        scores = self._data["data"]
        lower = np.min([0, helpers.float_round(scores.min(), 2) - 0.1])
        upper = helpers.float_round(scores.max(), 2) + 0.1
        return lower, upper

    @staticmethod
    def _save(plot_folder, model_setup):
        """
        Store the current figure as skill_score_competitive_{model_setup}.pdf (500dpi) and close it.

        :param plot_folder: path to save the plot
        :param model_setup: architecture type to specify plot name
        """
        plot_name = os.path.join(plot_folder, f"skill_score_competitive_{model_setup}.pdf")
        logging.debug(f"... save plot to {plot_name}")
        plt.savefig(plot_name, dpi=500)
        plt.close()
class PlotBootstrapSkillScore(RunEnvironment):
    """
    Box plot (over all stations) of the bootstrap skill scores.

    The x-axis shows the bootstrapped variables ("boot_var"); each forecast time step ("ahead") gets its own box. The
    y-axis range follows the data. The figure is stored in plot_folder as skill_score_bootstrap_{model_setup}.pdf with
    500dpi.
    """

    def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = ""):
        """
        Set attributes and create the plot.

        :param data: dictionary with station names as keys and 2D xarrays (axes ahead and boot_var) as values
        :param plot_folder: path to save the plot (default: current directory)
        :param model_setup: architecture type to specify plot name (default "")
        """
        super().__init__()
        self._labels = None
        self._data = self._prepare_data(data)
        self._plot(plot_folder, model_setup)

    def _prepare_data(self, data: Dict) -> pd.DataFrame:
        """
        Merge the per-station arrays into one long-format frame and derive the legend labels ("<ahead>d").

        :param data: dictionary with station names as keys and 2D xarrays as values
        :return: pre-processed data set
        """
        combined = helpers.dict_to_xarray(data, "station")
        self._labels = [f"{step}d" for step in combined.coords["ahead"].values]
        frame = combined.to_dataframe("data")
        return frame.reset_index(level=[0, 1, 2])

    def _label_add(self, score_only: bool):
        """
        Return "" when score_only is true, "terms and " otherwise (not used by this class' own plot).

        :param score_only: if false all terms are relevant, otherwise only CASE I to IV
        :return: additional label
        """
        if score_only:
            return ""
        return "terms and "

    def _plot(self, plot_folder, model_setup):
        """
        Draw the box plot (boot_var on x, one hue per lead time) and save it.

        :param plot_folder: path to save the plot
        :param model_setup: architecture type to specify plot name
        """
        fig, ax = plt.subplots()
        sns.boxplot(x="boot_var", y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d",
                    showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."})
        ax.axhline(y=0, color="grey", linewidth=.5)
        ax.set(ylabel="skill score", xlabel="", title="summary of all stations")
        legend_handles = ax.get_legend_handles_labels()[0]
        ax.legend(legend_handles, self._labels)
        plt.tight_layout()
        self._save(plot_folder, model_setup)

    @staticmethod
    def _save(plot_folder, model_setup):
        """
        Store the current figure as skill_score_bootstrap_{model_setup}.pdf (500dpi) and close all figures.

        :param plot_folder: path to save the plot
        :param model_setup: architecture type to specify plot name
        """
        file_name = f"skill_score_bootstrap_{model_setup}.pdf"
        plot_name = os.path.join(plot_folder, file_name)
        logging.debug(f"... save plot to {plot_name}")
        plt.savefig(plot_name, dpi=500)
        plt.close('all')
class PlotTimeSeries(RunEnvironment):
    """
    Plot observation ("obs") and CNN forecast time series per station, one pdf page per station.

    Each page holds one subplot per year (daily sampling) or per half year (hourly sampling); subplots containing only
    NaNs are removed. All pages are written to timeseries_plot.pdf in plot_folder.
    """

    def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = ".",
                 sampling="daily"):
        """
        Set attributes and create the plot.

        :param stations: all stations to plot
        :param data_path: path, where the data is located
        :param name: full name of the local files with a % as placeholder for the station name
        :param window_lead_time: number of lead times to plot; capped at the lead times available in the data of the
            first station (default None -> use all of them)
        :param plot_folder: path to save the plot (default: current directory)
        :param sampling: "daily" or "hourly" (default "daily")
        """
        super().__init__()
        self._data_path = data_path
        self._data_name = name
        self._stations = stations
        self._window_lead_time = self._get_window_lead_time(window_lead_time)
        self._sampling = self._get_sampling(sampling)
        self._plot(plot_folder)

    @staticmethod
    def _get_sampling(sampling):
        """Map "daily" to "D" and "hourly" to "h"; any other value yields None."""
        if sampling == "daily":
            return "D"
        elif sampling == "hourly":
            return "h"
        return None

    def _get_window_lead_time(self, window_lead_time: int):
        """
        Extract the lead time from data and arguments.

        If window_lead_time is not given, use the number of ahead steps in the data of the first station. If given, it
        is capped at that number.

        :param window_lead_time: lead time from arguments to validate
        :return: validated lead time, comes either from given argument or from data itself
        """
        ahead_steps = len(self._load_data(self._stations[0]).ahead)
        if window_lead_time is None:
            window_lead_time = ahead_steps
        return min(ahead_steps, window_lead_time)

    def _load_data(self, station):
        """Load the "CNN" and "obs" series of a station into memory."""
        logging.debug(f"... preprocess station {station}")
        file_name = os.path.join(self._data_path, self._data_name % station)
        # read the selection into memory so the file handle is released immediately (it was never closed before)
        with xr.open_dataarray(file_name) as data:
            return data.sel(type=["CNN", "obs"]).load()

    def _plot(self, plot_folder):
        """Create one page per station and write all pages to the pdf file."""
        # the context manager closes the pdf file even if plotting raises
        with self._create_pdf_pages(plot_folder) as pdf_pages:
            # NOTE(review): the year range is taken from the first station only and reused for all stations
            start, end = self._get_time_range(self._load_data(self._stations[0]))
            for station in self._stations:
                data = self._load_data(station)
                fig, axes, factor = self._create_subplots(start, end)
                nan_list = []
                for i_year in range(end - start + 1):
                    data_year = data.sel(index=f"{start + i_year}")
                    for i_half_of_year in range(factor):
                        pos = factor * i_year + i_half_of_year
                        plot_data = self._create_plot_data(data_year, factor, i_half_of_year)
                        self._plot_obs(axes[pos], plot_data)
                        self._plot_ahead(axes[pos], plot_data)
                        if np.isnan(plot_data.values).all():
                            nan_list.append(pos)
                self._clean_up_axes(nan_list, axes, fig)
                self._save_page(station, pdf_pages)
                # release the large per-station figure once its page is written
                plt.close(fig)
        plt.close('all')

    @staticmethod
    def _clean_up_axes(nan_list, axes, fig):
        """Remove the subplots at the given positions (those holding only NaNs) from the figure."""
        for i in reversed(nan_list):
            fig.delaxes(axes[i])

    @staticmethod
    def _save_page(station, pdf_pages):
        """Title the current figure with the station name, add a legend and append it as a page (500dpi)."""
        plt.suptitle(station)
        plt.legend()
        plt.tight_layout()
        pdf_pages.savefig(dpi=500)

    @staticmethod
    def _create_plot_data(data, factor, running_index):
        """For factor > 1 keep only January-June (running_index 0) or July-December (otherwise); else keep all."""
        if factor > 1:
            if running_index == 0:
                data = data.where(data["index.month"] < 7)
            else:
                data = data.where(data["index.month"] >= 7)
        return data

    def _create_subplots(self, start, end):
        """
        Create one row of subplots per year (two per year for hourly sampling).

        :return: figure, 1D array of axes, number of subplots per year
        """
        factor = 2 if self._sampling == "h" else 1
        # squeeze=False always yields a 2D array, so axes[pos] also works when only a single subplot is created
        f, ax = plt.subplots((end - start + 1) * factor, sharey=True, figsize=(50, 30), squeeze=False)
        return f, ax[:, 0], factor

    def _plot_ahead(self, ax, data):
        """Plot the CNN forecast for each lead time up to window_lead_time, shifted forward by its lead time."""
        color = sns.color_palette("Blues_d", self._window_lead_time).as_hex()
        # restrict to window_lead_time lead times: the palette has exactly that many colors, so iterating over all
        # lead times of the data raised an IndexError whenever window_lead_time was smaller
        for ahead in data.coords["ahead"].values[:self._window_lead_time]:
            plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze().shift(index=ahead)
            label = f"{ahead}{self._sampling}"
            ax.plot(plot_data, color=color[ahead - 1], label=label)

    def _plot_obs(self, ax, data):
        """Plot the observation (taken at ahead=1, shifted forward by one step) in green."""
        ahead = 1
        obs_data = data.sel(type="obs", ahead=ahead).shift(index=ahead)
        ax.plot(obs_data, color=matplotlib.colors.cnames["green"], label="obs")

    @staticmethod
    def _get_time_range(data):
        """Return the first and last calendar year of data's "index" coordinate."""
        years = pd.to_datetime(data.index.values).year
        return int(years.min()), int(years.max())

    @staticmethod
    def _create_pdf_pages(plot_folder):
        """
        Open the multi-page pdf timeseries_plot.pdf in plot_folder.

        :param plot_folder: path to save the plot
        :return: PdfPages object to which the station pages are appended
        """
        plot_name = os.path.join(os.path.abspath(plot_folder), 'timeseries_plot.pdf')
        logging.debug(f"... save plot to {plot_name}")
        return matplotlib.backends.backend_pdf.PdfPages(plot_name)