diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index d25736f1a9ca47984ac513805a9b458ff09ff667..c3fb7abc9c51378552ed91d2fa6b69e08fa351e7 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -1195,12 +1195,14 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover ax.set_ylim([ylims[0], ylims[1]*1.025]) ax.set_ylabel(f"{self.error_measure} (in {self.error_unit})") ax.set_xticklabels(ax.get_xticklabels(), rotation=45) + ax.set_xlabel(None) elif orientation == "h": if apply_u_test: ax = self.set_significance_bars(asteriks, ax, data_table, orientation) ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})") xlims = list(ax.get_xlim()) ax.set_xlim([xlims[0], xlims[1] * 1.015]) + ax.set_ylabel(None) else: raise ValueError(f"orientation must be `v' or `h' but is: {orientation}") text = f"n={n_boots}" @@ -1323,7 +1325,8 @@ class PlotTimeEvolutionMetric(AbstractPlotClass): class PlotSeasonalMSEStack(AbstractPlotClass): def __init__(self, data_path: str, plot_folder: str = ".", boot_dim="boots", ahead_dim="ahead", - sampling: str = "daily", error_measure: str = "MSE", error_unit: str = "ppb$^2$"): + sampling: str = "daily", error_measure: str = "MSE", error_unit: str = "ppb$^2$", + model_name: str = "NN", model_indicator: str = "nn", model_type_dim: str = "type"): """Set attributes and create plot.""" super().__init__(plot_folder, "seasonal_mse_stack_plot") self.plot_name_orig = "seasonal_mse_stack_plot" @@ -1331,13 +1334,13 @@ class PlotSeasonalMSEStack(AbstractPlotClass): self.season_dim = "season" self.error_unit = error_unit self.error_measure = error_measure - self._data = self._prepare_data(boot_dim, data_path) + self._data = self._prepare_data(boot_dim, data_path, model_type_dim, model_indicator, model_name) for orientation in ["horizontal", "vertical"]: for split_ahead in [True, False]: self._plot(ahead_dim, split_ahead, sampling, orientation) self._save(bbox_inches="tight") - def _prepare_data(self, boot_dim, data_path): + def _prepare_data(self, boot_dim, data_path, model_type_dim, model_indicator, model_name): season_dim = self.season_dim data = {} for season in ["total", "DJF", "MAM", "JJA", "SON"]: @@ -1353,7 +1356,9 @@ class PlotSeasonalMSEStack(AbstractPlotClass): xr_data = xr.Dataset(mean).to_array(season_dim) factor = xr_data.sel({season_dim: "total"}) / xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}).sum( season_dim) - return xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}) * factor + xr_data = xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}) * factor + xr_data[model_type_dim] = [v if v != model_indicator else model_name for v in xr_data[model_type_dim].values] + return xr_data @staticmethod def _get_target_sampling(sampling, pos): diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index d65a200161a7593fe03df5053328aa3f8cd77310..de58e9054aa1619ddb5b8fd1fb481b25bf089f5b 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -686,7 +686,9 @@ class PostProcessing(RunEnvironment): report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") PlotSeasonalMSEStack(data_path=report_path, plot_folder=self.plot_path, boot_dim=self.uncertainty_estimate_boot_dim, ahead_dim=self.ahead_dim, - sampling=self._sampling, error_measure="Mean Squared Error", error_unit=r"ppb$^2$") + sampling=self._sampling, error_measure="Mean Squared Error", error_unit=r"ppb$^2$", + model_indicator=self.forecast_indicator, model_name=self.model_display_name, + model_type_dim=self.model_type_dim) except Exception as e: logging.error(f"Could not create plot PlotSeasonalMSEStack due to the following error: {e}" f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")