Skip to content
Snippets Groups Projects
Commit a2ef0d08 authored by lukas leufen's avatar lukas leufen
Browse files

first implementation of plot_conditional_quantiles without documentation

parent e532578e
Branches
Tags
2 merge requests: !37 "include new development", !27 "Lukas issue032 feat plotting postprocessing"
Pipeline #28146 passed
...@@ -15,7 +15,7 @@ seaborn ...@@ -15,7 +15,7 @@ seaborn
dask==0.20.2 dask==0.20.2
toolz # for dask toolz # for dask
cloudpickle # for dask cloudpickle # for dask
cython cython==0.29.14
pyshp pyshp
six six
pyproj pyproj
......
...@@ -3,13 +3,19 @@ __date__ = '2019-12-17' ...@@ -3,13 +3,19 @@ __date__ = '2019-12-17'
import os import os
import logging import logging
import math
import warnings
import numpy as np
import xarray as xr import xarray as xr
import pandas as pd
import matplotlib import matplotlib
import seaborn as sns import seaborn as sns
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import cartopy.crs as ccrs import cartopy.crs as ccrs
import cartopy.feature as cfeature import cartopy.feature as cfeature
from matplotlib.backends.backend_pdf import PdfPages
logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.getLogger('matplotlib').setLevel(logging.WARNING)
...@@ -79,3 +85,74 @@ def station_map(generators, plot_folder="."): ...@@ -79,3 +85,74 @@ def station_map(generators, plot_folder="."):
plot_path = os.path.join(os.path.abspath(plot_folder), 'test_map_plot.pdf') plot_path = os.path.join(os.path.abspath(plot_folder), 'test_map_plot.pdf')
plt.savefig(plot_path) plt.savefig(plot_path)
plt.close('all') plt.close('all')
def plot_conditional_quantiles(stations, plot_folder=".", q=None, rolling_window=3, ref_name='orig', pred_name='CNN',
                               season="", forecast_path=None):
    """
    Plot conditional quantiles of the reference values given the binned forecast values.

    For every station the file ``forecasts_<station>_test.nc`` is read from `forecast_path`. The forecast
    (`pred_name`) is cut into bins of width 1 starting at 0, and for each bin the quantiles `q` of the reference
    (`ref_name`) are calculated over all time steps and stations. One page per lead time (``ahead`` coordinate) is
    written to ``test_conditional_quantiles[_<season>]_plot.pdf`` inside `plot_folder`. Each page shows the smoothed
    quantile lines, a 1:1 reference line and, on a secondary log-scaled axis, a histogram of the binned forecasts.

    Note: the line styles and the legend labels are tailored to the five default quantiles.

    :param stations: iterable of station names used to build the forecast file names
    :param plot_folder: directory in which the pdf is stored (default: current directory)
    :param q: quantiles to calculate (default: ``[.1, .25, .5, .75, .9]``)
    :param rolling_window: window size of the centered rolling mean applied along the bins (default 3)
    :param ref_name: entry of the ``type`` dimension holding the reference values (default 'orig')
    :param pred_name: entry of the ``type`` dimension holding the forecast values (default 'CNN')
    :param season: optional label appended to the file name and the plot title
    :param forecast_path: directory containing the forecast files, required
    :raises ValueError: if `forecast_path` is not given
    """
    # validate first so that nothing else is touched if the call cannot succeed
    if forecast_path is None:
        raise ValueError("Forecast path is not given but required.")
    if q is None:
        q = [.1, .25, .5, .75, .9]

    season_suffix = f"_{season}" if season else ""
    plot_name = f"test_conditional_quantiles{season_suffix}"
    plot_path = os.path.join(os.path.abspath(plot_folder), f"{plot_name}_plot.pdf")

    # Always load the default types (keeps the bin range of the default call unchanged) and add the requested
    # reference / forecast types if they differ, so that the lookups by ref_name and pred_name below can succeed.
    types = ['CNN', 'orig', 'OLS']
    for name in (pred_name, ref_name):
        if name not in types:
            types.append(name)

    linetype = [':', '-.', '--', '-.', ':']
    xlabel = 'forecast concentration (in ppb)'
    ylabel = 'observed concentration (in ppb)'

    # restrict the warning filters to this function instead of changing the process-wide configuration
    with warnings.catch_warnings():
        # ignore warnings if nans appear in quantile grouping
        warnings.filterwarnings("ignore", message="All-NaN slice encountered")
        # ignore warnings if mean is calculated on nans
        warnings.filterwarnings("ignore", message="Mean of empty slice")
        # ignore warnings for y tick = 0 on log scale (instead of 0.00001 or similar)
        warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.")

        # collect the forecasts of all stations in a single array
        data = []
        for station in stations:
            file_name = os.path.join(forecast_path, f"forecasts_{station}_test.nc")
            data_tmp = xr.open_dataarray(file_name)
            data.append(data_tmp.loc[:, :, types].assign_coords(station=station))
        data = xr.concat(data, dim='station').transpose('index', 'type', 'ahead', 'station')

        # integer bin edges from 0 up to the rounded-up overall maximum
        bins = np.arange(0, math.ceil(float(data.max())) + 1, 1).astype(int)

        # merge time and station into one sample dimension with a plain integer index
        data = data.stack(z=['index', 'station'])
        data.coords['z'] = range(len(data.coords['z']))

        # replace each forecast value by the upper edge of the bin it falls into
        binned = data.loc[pred_name, ...].to_pandas().T.apply(pd.cut, bins=bins, labels=bins[1:]).T
        data.loc[pred_name, ...] = binned.values

        quantile_panel = xr.DataArray(np.full([data.ahead.shape[0], len(q), bins[1:].shape[0]], np.nan),
                                      coords=[data.ahead, q, bins[1:]], dims=['ahead', 'quantiles', 'categories'])
        quantile_panel = quantile_panel.transpose('ahead', 'quantiles', 'categories')
        for category in bins[1:]:
            in_category = data.loc[pred_name, ...] == category
            quantile_panel.loc[..., category] = data.loc[ref_name, ...].where(in_category).quantile(q, dim=['z']).T

        try:
            # the context manager closes the pdf even if plotting a page fails
            with PdfPages(plot_path) as pdf_pages:
                y2_max = 0
                for iteration, d in enumerate(data.ahead):
                    logging.debug(f"plotting {d.values} time step(s) ahead")
                    smoothed = quantile_panel.loc[d, ...].rolling(categories=rolling_window, center=True).mean()
                    ax = smoothed.to_pandas().T.plot(style=linetype, color='black', legend=False)
                    ax2 = ax.twinx()
                    ax.plot([0, bins.max()], [0, bins.max()], color='k', label='reference 1:1', linewidth=.8)
                    handles, labels = ax.get_legend_handles_labels()
                    data.loc[pred_name, d, :].to_pandas().hist(bins=bins, ax=ax2, color='k', alpha=.3, grid=False,
                                                               rwidth=1)
                    # the quantile lines are symmetric in style, so the first three handles cover all five lines
                    plt.legend(handles[:3] + [handles[-1]],
                               ('.10th and .90th quantiles', '.25th and .75th quantiles', '.50th quantile',
                                'reference 1:1'),
                               loc='upper left', fontsize='large')
                    ax.set(xlim=(0, bins.max()), ylim=(0, bins.max()))
                    ax.set_xlabel(xlabel, fontsize='x-large')
                    ax.tick_params(axis='x', which='major', labelsize=15)
                    ax.set_ylabel(ylabel, fontsize='x-large')
                    ax.tick_params(axis='y', which='major', labelsize=15)
                    ax2.yaxis.label.set_color('gray')
                    ax2.tick_params(axis='y', colors='gray')
                    ax2.yaxis.labelpad = -15
                    ax2.set_yscale('log')
                    # take the histogram axis limit from the first page and reuse it on all following pages
                    if iteration == 0:
                        y2_max = ax2.get_ylim()[1] + 100
                    ax2.set(ylim=(0, y2_max * 10 ** 8), yticks=np.logspace(0, 4, 5))
                    ax2.set_ylabel('              sample size', fontsize='x-large')
                    ax2.tick_params(axis='y', which='major', labelsize=15)
                    title = f"{d.values} time step(s) ahead{f' ({season})' if season else ''}"
                    plt.title(title)
                    pdf_pages.savefig()
        finally:
            # release all figures, also if an error occurred while plotting
            plt.close('all')
...@@ -17,7 +17,7 @@ from src.model_modules.linear_model import OrdinaryLeastSquaredModel ...@@ -17,7 +17,7 @@ from src.model_modules.linear_model import OrdinaryLeastSquaredModel
from src import statistics from src import statistics
from src import helpers from src import helpers
from src.helpers import TimeTracking from src.helpers import TimeTracking
from src.plotting.postprocessing_plotting import plot_monthly_summary, plot_climsum_boxplot, station_map from src.plotting.postprocessing_plotting import plot_monthly_summary, plot_climsum_boxplot, station_map, plot_conditional_quantiles
class PostProcessing(RunEnvironment): class PostProcessing(RunEnvironment):
...@@ -43,9 +43,11 @@ class PostProcessing(RunEnvironment): ...@@ -43,9 +43,11 @@ class PostProcessing(RunEnvironment):
path = self.data_store.get("forecast_path", "general") path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
target_var = self.data_store.get("target_var", "general") target_var = self.data_store.get("target_var", "general")
station_map(generators={'b': self.test_data}, plot_folder=self.plot_path)
plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", window_lead_time, target_var, plot_conditional_quantiles(self.test_data.stations, plot_folder=self.plot_path, forecast_path=self.data_store.get("forecast_path", "general"))
plot_folder=self.plot_path) # station_map(generators={'b': self.test_data}, plot_folder=self.plot_path)
# plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", window_lead_time, target_var,
# plot_folder=self.plot_path)
# plot_climsum_boxplot() # plot_climsum_boxplot()
def calculate_test_score(self): def calculate_test_score(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment