diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 469b00d7da8954d1e96ab10f5372b338f9fa8b3e..03f0c5c5c12fd22e7bcb349662adc6c3e3154549 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -7,6 +7,7 @@ import math
 import warnings
 from src import helpers
 from src.helpers import TimeTracking
+from src.run_modules.run_environment import RunEnvironment
 
 import numpy as np
 import xarray as xr
@@ -24,59 +25,76 @@ from typing import Dict, List
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
 
-def plot_monthly_summary(stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
-                         plot_folder: str = "."):
+class PlotMonthlySummary(RunEnvironment):
     """
     Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot. The plot is saved
     in data_path with name monthly_summary_box_plot.pdf and 500dpi resolution.
-    :param stations: all stations to plot
-    :param data_path: path, where the data is located
-    :param name: full name of the local files with a % as placeholder for the station name
-    :param target_var: display name of the target variable on plot's axis
-    :param window_lead_time: lead time to plot, if window_lead_time is higher than the available lead time or not given
-        the maximum lead time from data is used. (default None -> use maximum lead time from data).
-    :param plot_folder: path to save the plot (default: current directory)
     """
-    logging.debug("run plot_monthly_summary()")
-    forecasts = None
-
-    for station in stations:
-        logging.debug(f"... preprocess station {station}")
-        file_name = os.path.join(data_path, name % station)
-        data = xr.open_dataarray(file_name)
-
-        data_cnn = data.sel(type="CNN").squeeze()
-        data_cnn.coords["ahead"].values = [f"{days}d" for days in data_cnn.coords["ahead"].values]
-
-        data_orig = data.sel(type="orig", ahead=1).squeeze()
-        data_orig.coords["ahead"] = "orig"
-
-        data_concat = xr.concat([data_orig, data_cnn], dim="ahead")
-        data_concat = data_concat.drop("type")
-
-        data_concat.index.values = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1
-        data_concat = data_concat.clip(min=0)
-
-        forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat
-
-    ahead_steps = len(forecasts.ahead)
-    if window_lead_time is None:
-        window_lead_time = ahead_steps
-    window_lead_time = min(ahead_steps, window_lead_time)
-
-    forecasts = forecasts.to_dataset(name='values').to_dask_dataframe()
-    logging.debug("... start plotting")
-    ax = sns.boxplot(x='index', y='values', hue='ahead', data=forecasts.compute(), whis=1.,
-                     palette=[matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d",
-                                                                                     window_lead_time).as_hex(),
-                     flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
-                     meanprops={'markersize': 1, 'markeredgecolor': 'k'})
-    ax.set(xlabel='month', ylabel=f'{target_var}')
-    plt.tight_layout()
-    plot_name = os.path.join(os.path.abspath(plot_folder), 'monthly_summary_box_plot.pdf')
-    logging.debug(f"... save plot to {plot_name}")
-    plt.savefig(plot_name, dpi=500)
-    plt.close('all')
+    def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
+                 plot_folder: str = "."):
+        """
+        Sets attributes and create plot
+        :param stations: all stations to plot
+        :param data_path: path, where the data is located
+        :param name: full name of the local files with a % as placeholder for the station name
+        :param target_var: display name of the target variable on plot's axis
+        :param window_lead_time: lead time to plot, if window_lead_time is higher than the available lead time or not given
+            the maximum lead time from data is used. (default None -> use maximum lead time from data).
+        :param plot_folder: path to save the plot (default: current directory)
+        """
+        super().__init__()
+        self._data_path = data_path
+        self._data_name = name
+        self._data = self._get_data(stations)
+        self._window_lead_time = self._get_window_lead_time(window_lead_time)
+        self._plot(target_var, plot_folder)
+
+    def _get_data(self, stations):
+        """
+        pre-process data
+        :param stations:
+        :return:
+        """
+        forecasts = None
+        for station in stations:
+            logging.debug(f"... preprocess station {station}")
+            file_name = os.path.join(self._data_path, self._data_name % station)
+            data = xr.open_dataarray(file_name)
+
+            data_cnn = data.sel(type="CNN").squeeze()
+            data_cnn.coords["ahead"].values = [f"{days}d" for days in data_cnn.coords["ahead"].values]
+
+            data_orig = data.sel(type="orig", ahead=1).squeeze()
+            data_orig.coords["ahead"] = "orig"
+
+            data_concat = xr.concat([data_orig, data_cnn], dim="ahead")
+            data_concat = data_concat.drop("type")
+
+            data_concat.index.values = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1
+            data_concat = data_concat.clip(min=0)
+
+            forecasts = xr.concat([forecasts, data_concat], 'index') if forecasts is not None else data_concat
+        return forecasts
+
+    def _get_window_lead_time(self, window_lead_time):
+        ahead_steps = len(self._data.ahead)
+        if window_lead_time is None:
+            window_lead_time = ahead_steps
+        return min(ahead_steps, window_lead_time)
+
+    def _plot(self, target_var, plot_folder):
+        data = self._data.to_dataset(name='values').to_dask_dataframe()
+        logging.debug("... start plotting")
+        color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
+        ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1., palette=color_palette,
+                         flierprops={'marker': '.', 'markersize': 1}, showmeans=True,
+                         meanprops={'markersize': 1, 'markeredgecolor': 'k'})
+        ax.set(xlabel='month', ylabel=f'{target_var}')
+        plt.tight_layout()
+        plot_name = os.path.join(os.path.abspath(plot_folder), 'monthly_summary_box_plot.pdf')
+        logging.debug(f"... save plot to {plot_name}")
+        plt.savefig(plot_name, dpi=500)
+        plt.close('all')
 
 
 def plot_station_map(generators: Dict, plot_folder: str = "."):
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index ecbc415f8b463690dcc75d7731d8ec859d74da33..c8fde35cb39df00e94aeeb2ca218e6c65c248b54 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -15,8 +15,8 @@ from src.data_handling.data_distributor import Distributor
 from src.data_handling.data_generator import DataGenerator
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
 from src import statistics
-from src.plotting.postprocessing_plotting import plot_monthly_summary, plot_station_map, plot_conditional_quantiles, \
-    plot_climatological_skill_score, plot_competitive_skill_score
+from src.plotting.postprocessing_plotting import plot_station_map, plot_conditional_quantiles, \
+    plot_climatological_skill_score, plot_competitive_skill_score, PlotMonthlySummary
 from src.datastore import NameNotFoundInDataStore
 
 
@@ -62,8 +62,8 @@ class PostProcessing(RunEnvironment):
         plot_conditional_quantiles(self.test_data.stations, pred_name="orig", ref_name="CNN",
                                    forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path)
         plot_station_map(generators={'b': self.test_data}, plot_folder=self.plot_path)
-        plot_monthly_summary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
-                             plot_folder=self.plot_path)
+        PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", target_var,
+                           plot_folder=self.plot_path)
 
         # plot_climatological_skill_score(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN")
         plot_climatological_skill_score(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
diff --git a/src/statistics.py b/src/statistics.py
index fd8491748f7ebd30d7056af9cc1ce162c5743881..df73784df830d5f7b96bf0fcd18a65d362516f12 100644
--- a/src/statistics.py
+++ b/src/statistics.py
@@ -136,7 +136,7 @@ class SkillScores(RunEnvironment):
         observation = data.sel(type=observation_name)
         forecast = data.sel(type=forecast_name)
         reference = data.sel(type=reference_name)
-        mse = statistics.mean_squared_error
+        mse = mean_squared_error
         skill_score = 1 - mse(observation, forecast) / mse(observation, reference)
         return skill_score.values
 
@@ -216,4 +216,4 @@ class SkillScores(RunEnvironment):
         for month in mu.month:
             monthly_mean[monthly_mean.index.dt.month == month, :] = mu[mu.month == month].values
 
-        return monthly_mean
\ No newline at end of file
+        return monthly_mean