diff --git a/run.py b/run.py index 9809712876dc886007b042a52d7b46c027800faf..fc61ae8788f9f72bc555f1fb12d8b58bb5224937 100644 --- a/run.py +++ b/run.py @@ -15,8 +15,8 @@ from src.run_modules.training import Training def main(parser_args): with RunEnvironment(): - ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], - station_type='background', trainable=False, create_new_model=False, window_history_size=6, + ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', ], # 'DEBW001'], + station_type='background', trainable=False, create_new_model=True, window_history_size=6, create_new_bootstraps=True) PreProcessing() diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 14e3074a7d8f09bd597fb2fbf53a298d83ab6556..32a84f25fefb24f40a40681592f1c9a7c5c5afcf 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -188,6 +188,246 @@ class PlotStationMap(AbstractPlotClass): self._plot_stations(generators) +@TimeTrackingWrapper +class PlotConditionalQuantiles(AbstractPlotClass): + """ + This class creates conditional quantile plots as originally proposed by Murphy, Brown and Chen (1989) + + Link to paper: https://journals.ametsoc.org/doi/pdf/10.1175/1520-0434%281989%29004%3C0485%3ADVOTF%3E2.0.CO%3B2 + """ + + def __init__(self, stations: List, data_pred_path: str, plot_folder: str = ".", plot_per_seasons=True, + rolling_window: int = 3, model_mame: str = "CNN", obs_name: str = "obs", **kwargs): + """ + + :param stations: all stations to plot + :param data_pred_path: path to dir which contains the forecasts as .nc files + :param plot_folder: path where the plots are stored + :param plot_per_seasons: if `True' create cond. quantile plots for seasons (DJF, MAM, JJA, SON) individually + :param rolling_window: smoothing of quantiles (3 is used by Murphy et al.) + :param model_mame: name of the model prediction as stored in netCDF file (for example "CNN") + :param obs_name: name of observation as stored in netCDF file (for example "obs") + :param kwargs: Some further arguments which are listed in self._opts + """ + + super().__init__(plot_folder, "conditional_quantiles") + + self._data_pred_path = data_pred_path + self._stations = stations + self._rolling_window = rolling_window + self._model_name = model_mame + self._obs_name = obs_name + + self._opts = {"q": kwargs.get("q", [.1, .25, .5, .75, .9]), + "linetype": kwargs.get("linetype", [':', '-.', '--', '-.', ':']), + "legend": kwargs.get("legend", + ['.10th and .90th quantile', '.25th and .75th quantile', '.50th quantile', + 'reference 1:1']), + "data_unit": kwargs.get("data_unit", "ppb"), + } + # self._data_unit = kwargs.get("data_unit", "ppb") + if plot_per_seasons is True: + self.seasons = ['DJF', 'MAM', 'JJA', 'SON'] + else: + self.seasons = "" + self._data = self._load_data() + self._bins = self._get_bins_from_rage_of_data() + + self._plot() + + def _load_data(self): + """ + This method loads forcast data + + :return: + """ + logging.debug("... load data") + data_collector = [] + for station in self._stations: + file = os.path.join(self._data_pred_path, f"forecasts_{station}_test.nc") + data_tmp = xr.open_dataarray(file) + data_collector.append(data_tmp.loc[:, :, [self._model_name, self._obs_name]].assign_coords(station=station)) + res = xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station') + return res + + def _segment_data(self, data, x_model): + """ + This method creates segmented data which is used for cond. quantile plots + + :param data: + :param x_model: + :return: + """ + logging.debug("... segment data") + # combine index and station to multi index + data = data.stack(z=['index', 'station']) + # replace multi index by simple position index (order is not relevant anymore) + data.coords['z'] = range(len(data.coords['z'])) + # segment data of x_model into bins + data.loc[x_model, ...] = data.loc[x_model, ...].to_pandas().T.apply(pd.cut, bins=self._bins, + labels=self._bins[1:]).T.values + return data + + @staticmethod + def _labels(plot_type, data_unit="ppb"): + """ + Helper method to correctly assign (x,y) labels to plots, depending on like-base or cali-ref factorization + + :param plot_type: + :param data_unit: + :return: + """ + names = (f"forecast concentration (in {data_unit})", f"observed concentration (in {data_unit})") + if plot_type == "obs": + return names + else: + return names[::-1] + + def _get_bins_from_rage_of_data(self): + """ + Get array of bins to use for quantiles + + :return: + """ + return np.arange(0, math.ceil(self._data.max().max()) + 1, 1).astype(int) + + def _create_quantile_panel(self, data, x_model, y_model): + """ + Clculate quantiles + + :param data: + :param x_model: + :param y_model: + :return: + """ + logging.debug("... create quantile panel") + # create empty xarray with dims: time steps ahead, quantiles, bin index (numbers create in previous step) + quantile_panel = xr.DataArray( + np.full([data.ahead.shape[0], len(self._opts["q"]), self._bins[1:].shape[0]], np.nan), + coords=[data.ahead, self._opts["q"], self._bins[1:]], dims=['ahead', 'quantiles', 'categories']) + # ensure that the coordinates are in the right order + quantile_panel = quantile_panel.transpose('ahead', 'quantiles', 'categories') + # calculate for each bin of the pred_name data the quantiles of the ref_name data + for bin in self._bins[1:]: + mask = (data.loc[x_model, ...] == bin) + quantile_panel.loc[..., bin] = data.loc[y_model, ...].where(mask).quantile(self._opts["q"], + dim=['z']).T + return quantile_panel + + @staticmethod + def add_affix(x): + """ + Helper method to add additional information on plot name + + :param x: + :return: + """ + return f"_{x}" if len(x) > 0 else "" + + def _prepare_plots(self, data, x_model, y_model): + """ + Get segmented_data and quantile_panel + + :param data: + :param x_model: + :param y_model: + :return: + """ + segmented_data = self._segment_data(data, x_model) + quantile_panel = self._create_quantile_panel(segmented_data, x_model, y_model) + return segmented_data, quantile_panel + + def _plot(self): + """ + Main plot call + + :return: + """ + if len(self.seasons) > 0: + self._plot_seasons() + self._plot_all() + + def _plot_seasons(self): + """ + Seasonal plot call + + :return: + """ + for season in self.seasons: + self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._model_name, + y_model=self._obs_name, plot_name_affix="cali-ref", season=season) + self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._obs_name, + y_model=self._model_name, plot_name_affix="like-base", season=season) + + def _plot_all(self): + """ + Full plot call + + :return: + """ + self._plot_base(data=self._data, x_model=self._model_name, y_model=self._obs_name, plot_name_affix="cali-ref") + self._plot_base(data=self._data, x_model=self._obs_name, y_model=self._model_name, plot_name_affix="like-base") + + def _plot_base(self, data, x_model, y_model, plot_name_affix, season=""): + """ + Base method to create cond. quantile plots. Is called from _plot_all and _plot_seasonal + + :param data: data which is used to create cond. quantile plot + :param x_model: name of model on x axis (can also be obs) + :param y_model: name of model on y axis (can also be obs) + :param plot_name_affix: should be `cali-ref' or `like-base' + :param season: List of seasons to use + :return: + """ + segmented_data, quantile_panel = self._prepare_plots(data, x_model, y_model) + ylabel, xlabel = self._labels(x_model, self._opts["data_unit"]) + plot_name = f"{self.plot_name}{self.add_affix(season)}{self.add_affix(plot_name_affix)}_plot.pdf" + #f"{base_name}{add_affix(season)}{add_affix(plot_name_affix)}_plot.pdf" + plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name) + pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) + logging.debug(f"... plot path is {plot_path}") + + # create plot for each time step ahead + y2_max = 0 + for iteration, d in enumerate(segmented_data.ahead): + logging.debug(f"... plotting {d.values} time step(s) ahead") + # plot smoothed lines with rolling mean + smooth_data = quantile_panel.loc[d, ...].rolling(categories=self._rolling_window, + center=True).mean().to_pandas().T + ax = smooth_data.plot(style=self._opts["linetype"], color='black', legend=False) + ax2 = ax.twinx() + # add reference line + ax.plot([0, self._bins.max()], [0, self._bins.max()], color='k', label='reference 1:1', linewidth=.8) + # add histogram of the segmented data (pred_name) + handles, labels = ax.get_legend_handles_labels() + segmented_data.loc[x_model, d, :].to_pandas().hist(bins=self._bins, ax=ax2, color='k', alpha=.3, grid=False, + rwidth=1) + # add legend + plt.legend(handles[:3] + [handles[-1]], self._opts["legend"], loc='upper left', fontsize='large') + # adjust limits and set labels + ax.set(xlim=(0, self._bins.max()), ylim=(0, self._bins.max())) + ax.set_xlabel(xlabel, fontsize='x-large') + ax.tick_params(axis='x', which='major', labelsize=15) + ax.set_ylabel(ylabel, fontsize='x-large') + ax.tick_params(axis='y', which='major', labelsize=15) + ax2.yaxis.label.set_color('gray') + ax2.tick_params(axis='y', colors='gray') + ax2.yaxis.labelpad = -15 + ax2.set_yscale('log') + if iteration == 0: + y2_max = ax2.get_ylim()[1] + 100 + ax2.set(ylim=(0, y2_max * 10 ** 8), yticks=np.logspace(0, 4, 5)) + ax2.set_ylabel(' sample size', fontsize='x-large') + ax2.tick_params(axis='y', which='major', labelsize=15) + # set title and save current figure + title = f"{d.values} time step(s) ahead{f' ({season})' if len(season) > 0 else ''}" + plt.title(title) + pdf_pages.savefig() + # close all open figures / plots + pdf_pages.close() + plt.close('all') + + @TimeTrackingWrapper def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_window: int = 3, ref_name: str = 'obs', pred_name: str = 'CNN', season: str = "", forecast_path: str = None, @@ -264,7 +504,7 @@ def plot_conditional_quantiles(stations: list, plot_folder: str = ".", rolling_w "xlabel": xlabel, "ylabel": ylabel} # set name and path of the plot - base_name = "conditional_quantiles" + base_name = "conditional_quantiles-orig" def add_affix(x): return f"_{x}" if len(x) > 0 else "" plot_name = f"{base_name}{add_affix(season)}{add_affix(plot_name_affix)}_plot.pdf" plot_path = os.path.join(os.path.abspath(plot_folder), plot_name) @@ -697,7 +937,6 @@ class PlotAvailability(AbstractPlotClass): plt_dict[summary_name].update({subset: t2}) return plt_dict - def _plot(self, plt_dict): # colors = {"train": "orange", "val": "blueishgreen", "test": "skyblue"} # color names colors = {"train": "#e69f00", "val": "#009e73", "test": "#56b4e9"} # hex code @@ -722,3 +961,17 @@ class PlotAvailability(AbstractPlotClass): handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items()] lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles)) return lgd + + +if __name__ == "__main__": + stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] + path = "../../testrun_network/forecasts" + plt_path = "../../" + + con_quan_cls = PlotConditionalQuantiles(stations, path, plt_path) + + # con_quan_cls = PlotConditionalQuantiles(stations, data_pred_path=path, plot_name_affix="", pred_name="CNN", + # ref_name="obs", plot_folder=plt_path, seasons=None) + plot_conditional_quantiles(stations, pred_name="CNN", ref_name="obs", + forecast_path=path, plot_name_affix="cali-ref-orig", plot_folder=plt_path) + diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 8a962888ec0b789a14a24b20c97148e7a8315b30..b1f3afeae4e4c15ffdbbe0da252737256ebe5687 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -20,7 +20,7 @@ from src.helpers import TimeTracking from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src.model_modules.model_class import AbstractModelClass from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ - PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability + PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles from src.plotting.postprocessing_plotting import plot_conditional_quantiles from src.run_modules.run_environment import RunEnvironment @@ -200,6 +200,7 @@ class PostProcessing(RunEnvironment): forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path) plot_conditional_quantiles(self.test_data.stations, pred_name="obs", ref_name="CNN", forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path) + PlotConditionalQuantiles(self.test_data.stations, data_pred_path=path, plot_folder=self.plot_path) if "PlotStationMap" in plot_list: PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path) if "PlotMonthlySummary" in plot_list: