diff --git a/requirements.txt b/requirements.txt index f6c1eb24615a501cd0af4cae9f8bbc5c015a3cb7..270c084865fbff00e6346b5f267c8d939a1d9902 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ seaborn dask==0.20.2 toolz # for dask cloudpickle # for dask -cython +cython==0.29.14 pyshp six pyproj diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 9ad7f87f5324dfb57c4cd4352e9daef19a493db1..338120591ce4d05b86c208dc6672a5e51a48d86f 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -3,13 +3,19 @@ __date__ = '2019-12-17' import os import logging +import math +import warnings +import numpy as np import xarray as xr +import pandas as pd + import matplotlib import seaborn as sns import matplotlib.pyplot as plt import cartopy.crs as ccrs import cartopy.feature as cfeature +from matplotlib.backends.backend_pdf import PdfPages logging.getLogger('matplotlib').setLevel(logging.WARNING) @@ -42,8 +48,8 @@ def plot_monthly_summary(stations, data_path, name: str, window_lead_time, targe logging.debug("... start plotting") ax = sns.boxplot(x='index', y='values', hue='ahead', data=forecasts.compute(), whis=1., palette=[matplotlib.colors.cnames["green"]]+sns.color_palette("Blues_d", window_lead_time).as_hex(), - flierprops={'marker':'.', 'markersize':1}, - showmeans=True, meanprops={'markersize':1, 'markeredgecolor':'k'}) + flierprops={'marker': '.', 'markersize': 1}, + showmeans=True, meanprops={'markersize': 1, 'markeredgecolor': 'k'}) ax.set(xlabel='month', ylabel=f'{target_var}') plt.tight_layout() plot_path = os.path.join(os.path.abspath(plot_folder), 'test_monthly_box.pdf') @@ -79,3 +85,74 @@ def station_map(generators, plot_folder="."): plot_path = os.path.join(os.path.abspath(plot_folder), 'test_map_plot.pdf') plt.savefig(plot_path) plt.close('all') + + +def plot_conditional_quantiles(stations, plot_folder=".", q=None, rolling_window=3, ref_name='orig', pred_name='CNN', season="", forecast_path=None): + # ignore warnings if nans appear in quantile grouping + warnings.filterwarnings("ignore", message="All-NaN slice encountered") + # ignore warnings if mean is calculated on nans + warnings.filterwarnings("ignore", message="Mean of empty slice") + # ignore warnings for y tick = 0 on log scale (instead of 0.00001 or similar) + warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.") + + plot_name = f"test_conditional_quantiles{f'_{season}' if len(season) > 0 else ''}" + plot_path = os.path.join(os.path.abspath(plot_folder), f"{plot_name}_plot.pdf") + + if q is None: + q = [.1, .25, .5, .75, .9] + + if forecast_path is None: + raise ValueError("Forecast path is not given but required.") + + data = [] + for station in stations: + file = os.path.join(forecast_path, f"forecasts_{station}_test.nc") + data_tmp = xr.open_dataarray(file) + data.append(data_tmp.loc[:, :, ['CNN', 'orig', 'OLS']].assign_coords(station=station)) + data = xr.concat(data, dim='station').transpose('index', 'type', 'ahead', 'station') + + linetype = [':', '-.', '--', '-.', ':'] + bins = np.arange(0, math.ceil(data.max().max()) + 1, 1).astype(int) + xlabel = 'forecast concentration (in ppb)' + ylabel = 'observed concentration (in ppb)' + + data = data.stack(z=['index', 'station']) + data.coords['z'] = range(len(data.coords['z'])) + data.loc[pred_name, ...] = data.loc[pred_name, ...].to_pandas().T.apply(pd.cut, bins=bins, labels=bins[1:]).T.values + quantile_panel = xr.DataArray(np.full([data.ahead.shape[0], len(q), bins[1:].shape[0]], np.nan), + coords=[data.ahead, q, bins[1:]], dims=['ahead', 'quantiles', 'categories']) + quantile_panel = quantile_panel.transpose('ahead', 'quantiles', 'categories') + for bin in bins[1:]: + quantile_panel.loc[..., bin] = data.loc[ref_name, ...].where(data.loc[pred_name, ...] == bin).quantile(q, dim=['z']).T + pdf_pages = PdfPages(plot_path) + + y2_max = 0 + for iteration, d in enumerate(data.ahead): + logging.debug(f"plotting {d.values} time step(s) ahead") + ax = quantile_panel.loc[d, ...].rolling(categories=rolling_window, center=True).mean().to_pandas().T.plot(style=linetype, color='black', legend=False) + ax2 = ax.twinx() + ax.plot([0, bins.max()], [0, bins.max()], color='k', label='reference 1:1', linewidth=.8) + handles, labels = ax.get_legend_handles_labels() + data.loc[pred_name, d, :].to_pandas().hist(bins=bins, ax=ax2, color='k', alpha=.3, grid=False, rwidth=1) + plt.legend(handles[:3]+[handles[-1]], ('.10th and .90th quantiles', '.25th and .75th quantiles', '.50th quantile', 'reference 1:1'), loc='upper left', fontsize='large') + ax.set(xlim=(0, bins.max()), ylim=(0, bins.max())) + ax.set_xlabel(xlabel, fontsize='x-large') + ax.tick_params(axis='x', which='major', labelsize=15) + ax.set_ylabel(ylabel, fontsize='x-large') + ax.tick_params(axis='y', which='major', labelsize=15) + ax2.yaxis.label.set_color('gray') + ax2.tick_params(axis='y', colors='gray') + ax2.yaxis.labelpad = -15 + ax2.set_yscale('log') + if iteration == 0: + y2_max = ax2.get_ylim()[1]+100 + ax2.set(ylim=(0, y2_max*10**8), yticks=np.logspace(0, 4, 5)) + ax2.set_ylabel(' sample size', fontsize='x-large') + ax2.tick_params(axis='y', which='major', labelsize=15) + title = f"{d.values} time step(s) ahead{f' ({season})' if len(season) > 0 else ''}" + plt.title(title) + pdf_pages.savefig() + pdf_pages.close() + plt.close('all') + + diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 96c8a40b772aadc4e25d6c54460c6e3ecd99336e..7ab6abf9857c647617ec42716d471412e9c81392 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -17,7 +17,7 @@ from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src import statistics from src import helpers from src.helpers import TimeTracking -from src.plotting.postprocessing_plotting import plot_monthly_summary, plot_climsum_boxplot, station_map +from src.plotting.postprocessing_plotting import plot_monthly_summary, plot_climsum_boxplot, station_map, plot_conditional_quantiles class PostProcessing(RunEnvironment): @@ -43,9 +43,11 @@ class PostProcessing(RunEnvironment): path = self.data_store.get("forecast_path", "general") window_lead_time = self.data_store.get("window_lead_time", "general") target_var = self.data_store.get("target_var", "general") - station_map(generators={'b': self.test_data}, plot_folder=self.plot_path) - plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", window_lead_time, target_var, - plot_folder=self.plot_path) + + plot_conditional_quantiles(self.test_data.stations, plot_folder=self.plot_path, forecast_path=self.data_store.get("forecast_path", "general")) + # station_map(generators={'b': self.test_data}, plot_folder=self.plot_path) + # plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", window_lead_time, target_var, + # plot_folder=self.plot_path) # plot_climsum_boxplot() def calculate_test_score(self):