diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 1a8237e31a6782cf7e53a088012cb67a1a125747..8c8ea98e9f356be0b9064afcfcab73d00df67311 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -5,6 +5,7 @@ import os import logging import math import warnings +from src.helpers import TimeTracking import numpy as np import xarray as xr @@ -87,7 +88,23 @@ def plot_station_map(generators, plot_folder="."): plt.close('all') -def plot_conditional_quantiles(stations, plot_folder=".", q=None, rolling_window=3, ref_name='orig', pred_name='CNN', season="", forecast_path=None): +def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_window: int = 3, ref_name: str = 'orig', + pred_name: str = 'CNN', season: str = "", forecast_path: str = None): + """ + This plot was originally taken from Murphy, Brown and Chen (1989): + https://journals.ametsoc.org/doi/pdf/10.1175/1520-0434%281989%29004%3C0485%3ADVOTF%3E2.0.CO%3B2 + + :param stations: stations to include in the plot (forecast data needs to be available already) + :param plot_folder: path to save the plot (default: current directory) + :param rolling_window: the rolling window mean will smooth the plot appearance (no smoothing in bin calculation, + this is only a cosmetic step, default: 3) + :param ref_name: name of the reference data series + :param pred_name: name of the investigated data series + :param season: season name to highlight if not empty + :param forecast_path: path to save the plot file + """ + time = TimeTracking() + logging.debug(f"started plot_conditional_quantiles()") # ignore warnings if nans appear in quantile grouping warnings.filterwarnings("ignore", message="All-NaN slice encountered") # ignore warnings if mean is calculated on nans @@ -95,50 +112,85 @@ def plot_conditional_quantiles(stations, plot_folder=".", q=None, rolling_window # ignore warnings for y tick = 0 on log scale (instead of 0.00001 or similar) warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.") + def load_data(): + logging.debug("... load data") + data_collector = [] + for station in stations: + file = os.path.join(forecast_path, f"forecasts_{station}_test.nc") + data_tmp = xr.open_dataarray(file) + data_collector.append(data_tmp.loc[:, :, ['CNN', 'orig', 'OLS']].assign_coords(station=station)) + return xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station') + + def segment_data(data): + logging.debug("... segment data") + # combine index and station to multi index + data = data.stack(z=['index', 'station']) + # replace multi index by simple position index (order is not relevant anymore) + data.coords['z'] = range(len(data.coords['z'])) + # segment data of pred_name into bins + data.loc[pred_name, ...] = data.loc[pred_name, ...].to_pandas().T.apply(pd.cut, bins=bins, + labels=bins[1:]).T.values + return data + + def create_quantile_panel(data, q): + logging.debug("... create quantile panel") + # create empty xarray with dims: time steps ahead, quantiles, bin index (numbers create in previous step) + quantile_panel = xr.DataArray(np.full([data.ahead.shape[0], len(q), bins[1:].shape[0]], np.nan), + coords=[data.ahead, q, bins[1:]], dims=['ahead', 'quantiles', 'categories']) + # ensure that the coordinates are in the right order + quantile_panel = quantile_panel.transpose('ahead', 'quantiles', 'categories') + # calculate for each bin of the pred_name data the quantiles of the ref_name data + for bin in bins[1:]: + mask = (data.loc[pred_name, ...] == bin) + quantile_panel.loc[..., bin] = data.loc[ref_name, ...].where(mask).quantile(q, dim=['z']).T + + return quantile_panel + + opts = {"q": [.1, .25, .5, .75, .9], + "linetype": [':', '-.', '--', '-.', ':'], + "legend": ['.10th and .90th quantile', '.25th and .75th quantile', '.50th quantile', 'reference 1:1'], + "xlabel": "forecast concentration (in ppb)", + "ylabel": "observed concentration (in ppb)"} + + # set name and path of the plot plot_name = f"test_conditional_quantiles{f'_{season}' if len(season) > 0 else ''}" plot_path = os.path.join(os.path.abspath(plot_folder), f"{plot_name}_plot.pdf") - if q is None: - q = [.1, .25, .5, .75, .9] - + # check forecast path if forecast_path is None: raise ValueError("Forecast path is not given but required.") - data = [] - for station in stations: - file = os.path.join(forecast_path, f"forecasts_{station}_test.nc") - data_tmp = xr.open_dataarray(file) - data.append(data_tmp.loc[:, :, ['CNN', 'orig', 'OLS']].assign_coords(station=station)) - data = xr.concat(data, dim='station').transpose('index', 'type', 'ahead', 'station') - - linetype = [':', '-.', '--', '-.', ':'] - bins = np.arange(0, math.ceil(data.max().max()) + 1, 1).astype(int) - xlabel = 'forecast concentration (in ppb)' - ylabel = 'observed concentration (in ppb)' - - data = data.stack(z=['index', 'station']) - data.coords['z'] = range(len(data.coords['z'])) - data.loc[pred_name, ...] = data.loc[pred_name, ...].to_pandas().T.apply(pd.cut, bins=bins, labels=bins[1:]).T.values - quantile_panel = xr.DataArray(np.full([data.ahead.shape[0], len(q), bins[1:].shape[0]], np.nan), - coords=[data.ahead, q, bins[1:]], dims=['ahead', 'quantiles', 'categories']) - quantile_panel = quantile_panel.transpose('ahead', 'quantiles', 'categories') - for bin in bins[1:]: - quantile_panel.loc[..., bin] = data.loc[ref_name, ...].where(data.loc[pred_name, ...] == bin).quantile(q, dim=['z']).T - pdf_pages = PdfPages(plot_path) + # load data and set data bins + orig_data = load_data() + bins = np.arange(0, math.ceil(orig_data.max().max()) + 1, 1).astype(int) + segmented_data = segment_data(orig_data) + quantile_panel = create_quantile_panel(segmented_data, q=opts["q"]) + + # init pdf output + pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) + logging.debug(f"... plot path is {plot_path}") + # create plot for each time step ahead y2_max = 0 - for iteration, d in enumerate(data.ahead): - logging.debug(f"plotting {d.values} time step(s) ahead") - ax = quantile_panel.loc[d, ...].rolling(categories=rolling_window, center=True).mean().to_pandas().T.plot(style=linetype, color='black', legend=False) + for iteration, d in enumerate(segmented_data.ahead): + logging.debug(f"... plotting {d.values} time step(s) ahead") + # plot smoothed lines with rolling mean + smooth_data = quantile_panel.loc[d, ...].rolling(categories=rolling_window, center=True).mean().to_pandas().T + ax = smooth_data.plot(style=opts["linetype"], color='black', legend=False) ax2 = ax.twinx() + # add reference line ax.plot([0, bins.max()], [0, bins.max()], color='k', label='reference 1:1', linewidth=.8) + # add histogram of the segmented data (pred_name) handles, labels = ax.get_legend_handles_labels() - data.loc[pred_name, d, :].to_pandas().hist(bins=bins, ax=ax2, color='k', alpha=.3, grid=False, rwidth=1) - plt.legend(handles[:3]+[handles[-1]], ('.10th and .90th quantiles', '.25th and .75th quantiles', '.50th quantile', 'reference 1:1'), loc='upper left', fontsize='large') + segmented_data.loc[pred_name, d, :].to_pandas().hist(bins=bins, ax=ax2, color='k', + alpha=.3, grid=False, rwidth=1) + # add legend + plt.legend(handles[:3]+[handles[-1]], opts["legend"], loc='upper left', fontsize='large') + # adjust limits and set labels ax.set(xlim=(0, bins.max()), ylim=(0, bins.max())) - ax.set_xlabel(xlabel, fontsize='x-large') + ax.set_xlabel(opts["xlabel"], fontsize='x-large') ax.tick_params(axis='x', which='major', labelsize=15) - ax.set_ylabel(ylabel, fontsize='x-large') + ax.set_ylabel(opts["ylabel"], fontsize='x-large') ax.tick_params(axis='y', which='major', labelsize=15) ax2.yaxis.label.set_color('gray') ax2.tick_params(axis='y', colors='gray') @@ -149,10 +201,11 @@ def plot_conditional_quantiles(stations, plot_folder=".", q=None, rolling_window ax2.set(ylim=(0, y2_max*10**8), yticks=np.logspace(0, 4, 5)) ax2.set_ylabel(' sample size', fontsize='x-large') ax2.tick_params(axis='y', which='major', labelsize=15) + # set title and save current figure title = f"{d.values} time step(s) ahead{f' ({season})' if len(season) > 0 else ''}" plt.title(title) pdf_pages.savefig() + # close all open figures / plots pdf_pages.close() plt.close('all') - - + logging.info(f"plot_conditional_quantiles() finished after {time}")