From b2003ea31db70cc2f74a39d7edea8bd8503fd05c Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 30 Aug 2022 14:24:44 +0200
Subject: [PATCH] new plot avail PlotSeasonalMSEStack, not in default plot list. \close #422

---
 mlair/plotting/postprocessing_plotting.py | 123 ++++++++++++++++++++++
 mlair/run_modules/post_processing.py      |  12 ++-
 2 files changed, 134 insertions(+), 1 deletion(-)

diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index c7647ef5..d25736f1 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -1319,6 +1319,129 @@ class PlotTimeEvolutionMetric(AbstractPlotClass):
         self._save()
 
 
+@TimeTrackingWrapper
+class PlotSeasonalMSEStack(AbstractPlotClass):
+    """
+    Create stacked bar plots that show the contribution of each season to the overall error.
+
+    The mean error of each season (DJF, MAM, JJA, SON) is rescaled so that the seasonal values sum up to the mean
+    error of the total period. For each orientation (horizontal, vertical) two plots are created: one with the
+    error averaged along ahead_dim (plot name ends with "_total_" + orientation) and one with a separate panel
+    for each entry of ahead_dim (plot name ends with "_" + orientation).
+
+    :param data_path: folder containing uncertainty_estimate_raw_results.nc (total period) and a file
+        uncertainty_estimate_raw_results_{season}.nc for each season
+    :param plot_folder: folder handed to the parent class to store the plots in (default ".")
+    :param boot_dim: name of the dimension that is averaged out before stacking (default "boots")
+    :param ahead_dim: name of the dimension used to split the plot into panels (default "ahead")
+    :param sampling: either a single string or a pair of strings, the second entry is used to select the letter
+        ("H" for "hourly", "d" for "daily") appended to the panel titles (default "daily")
+    :param error_measure: name of the error measure used in the axis label (default "MSE")
+    :param error_unit: unit of the error measure used in the axis label (default "ppb$^2$")
+    """
+
+    _seasons = ["DJF", "MAM", "JJA", "SON"]
+
+    def __init__(self, data_path: str, plot_folder: str = ".", boot_dim="boots", ahead_dim="ahead",
+                 sampling: str = "daily", error_measure: str = "MSE", error_unit: str = "ppb$^2$"):
+        """Set attributes and create plot."""
+        super().__init__(plot_folder, "seasonal_mse_stack_plot")
+        self.plot_name_orig = "seasonal_mse_stack_plot"
+        self._data_path = data_path
+        self.season_dim = "season"
+        self.error_unit = error_unit
+        self.error_measure = error_measure
+        self._data = self._prepare_data(boot_dim, data_path)
+        for orientation in ["horizontal", "vertical"]:
+            for split_ahead in [True, False]:
+                self._plot(ahead_dim, split_ahead, sampling, orientation)
+                self._save(bbox_inches="tight")
+
+    def _prepare_data(self, boot_dim, data_path):
+        """
+        Load the error estimates and rescale the seasonal means to sum up to the mean of the total period.
+
+        :param boot_dim: name of the dimension to average along
+        :param data_path: folder containing the netCDF files with the error estimates
+        :return: seasonal means multiplied by the ratio of the total mean to the sum of the seasonal means
+        """
+        season_dim = self.season_dim
+        mean = {}
+        for season in ["total", *self._seasons]:
+            suffix = "" if season == "total" else f"_{season}"
+            file_name = f"uncertainty_estimate_raw_results{suffix}.nc"
+            with xr.open_dataarray(os.path.join(data_path, file_name)) as d:
+                # reduce and load into memory here, the file is closed as soon as the with block is left
+                mean[season] = d.mean(boot_dim).load()
+        xr_data = xr.Dataset(mean).to_array(season_dim)
+        seasonal = xr_data.sel({season_dim: self._seasons})
+        factor = xr_data.sel({season_dim: "total"}) / seasonal.sum(season_dim)
+        return seasonal * factor
+
+    @staticmethod
+    def _get_target_sampling(sampling, pos):
+        """Return sampling (a single string is duplicated to a pair) and the letter of the entry at position pos."""
+        sampling = (sampling, sampling) if isinstance(sampling, str) else sampling
+        sampling_letter = {"hourly": "H", "daily": "d"}.get(sampling[pos], "")
+        return sampling, sampling_letter
+
+    def _plot(self, dim, split_ahead=True, sampling="daily", orientation="vertical"):
+        """
+        Draw a single stacked bar plot and set the plot name accordingly.
+
+        :param dim: name of the dimension to either average along or to create the panels from
+        :param split_ahead: create a separate panel for each entry of dim if True, average along dim otherwise
+        :param sampling: sampling used to select the letter appended to the panel titles
+        :param orientation: draw vertical bars if "vertical", horizontal bars for any other value
+        """
+        _, sampling_letter = self._get_target_sampling(sampling, 1)
+        axis_label = f"{self.error_measure} (in {self.error_unit})"
+        if split_ahead is False:
+            self.plot_name = self.plot_name_orig + "_total_" + orientation
+            data = self._data.mean(dim)
+            if orientation == "vertical":
+                fig, ax = plt.subplots(1, 1)
+                data.to_pandas().T.plot.bar(ax=ax, stacked=True, cmap="Dark2", legend=False)
+                ax.xaxis.label.set_visible(False)
+                ax.set_ylabel(axis_label)
+            else:
+                m = data.to_pandas().T.shape[0]
+                fig, ax = plt.subplots(1, 1, figsize=(6, m))
+                data.to_pandas().T.plot.barh(ax=ax, stacked=True, cmap="Dark2", legend=False)
+                ax.yaxis.label.set_visible(False)
+                ax.set_xlabel(axis_label)
+            fig.legend(*ax.get_legend_handles_labels(), loc="upper center", ncol=4)
+            fig.tight_layout(rect=[0, 0, 1, 0.9])
+        else:
+            self.plot_name = self.plot_name_orig + "_" + orientation
+            data = self._data
+            n = len(data.coords[dim])
+            if orientation == "vertical":
+                # squeeze=False: always get an array of axes, even if dim has only a single entry
+                fig, axes = plt.subplots(1, n, sharey=True, squeeze=False)
+                ax = axes.ravel()
+                for i, sel in enumerate(data.coords[dim].values):
+                    data.sel({dim: sel}).to_pandas().T.plot.bar(ax=ax[i], stacked=True, cmap="Dark2", legend=False)
+                    label = str(sel) + sampling_letter
+                    ax[i].set_title(label)
+                    ax[i].xaxis.label.set_visible(False)
+                ax[0].set_ylabel(axis_label)
+                fig.legend(*ax[0].get_legend_handles_labels(), loc="upper center", ncol=4)
+                fig.tight_layout(rect=[0, 0, 1, 0.9])
+            else:
+                m = data.max(self.season_dim).shape
+                fig, axes = plt.subplots(n, 1, sharex=True, squeeze=False, figsize=(6, np.prod(m) * 0.6))
+                ax = axes.ravel()
+                for i, sel in enumerate(data.coords[dim].values):
+                    data.sel({dim: sel}).to_pandas().T.plot.barh(ax=ax[i], stacked=True, cmap="Dark2", legend=False)
+                    label = str(sel) + sampling_letter
+                    ax[i].set_title(label)
+                    ax[i].yaxis.label.set_visible(False)
+                ax[-1].set_xlabel(axis_label)
+                fig.legend(*ax[0].get_legend_handles_labels(), loc="upper center", ncol=4)
+                fig.tight_layout(rect=[0, 0, 1, 0.95])
+
+
 if __name__ == "__main__":
     stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
     path = "../../testrun_network/forecasts"
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index a48a82b2..d65a2001 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -24,7 +24,7 @@ from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules import AbstractModelClass
 from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \
     PlotCompetitiveSkillScore, PlotTimeSeries, PlotFeatureImportanceSkillScore, PlotConditionalQuantiles, \
-    PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric
+    PlotSeparationOfScales, PlotSampleUncertaintyFromBootstrap, PlotTimeEvolutionMetric, PlotSeasonalMSEStack
 from mlair.plotting.data_insight_plotting import PlotStationMap, PlotAvailability, PlotAvailabilityHistogram, \
     PlotPeriodogram, PlotDataHistogram
 from mlair.run_modules.run_environment import RunEnvironment
@@ -681,6 +681,16 @@ class PostProcessing(RunEnvironment):
             logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}"
                           f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
 
+        try:
+            if "PlotSeasonalMSEStack" in plot_list:
+                report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+                PlotSeasonalMSEStack(data_path=report_path, plot_folder=self.plot_path,
+                                     boot_dim=self.uncertainty_estimate_boot_dim, ahead_dim=self.ahead_dim,
+                                     sampling=self._sampling, error_measure="Mean Squared Error", error_unit=r"ppb$^2$")
+        except Exception as e:
+            logging.error(f"Could not create plot PlotSeasonalMSEStack due to the following error: {e}"
+                          f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")
+
     @TimeTrackingWrapper
     def calculate_test_score(self):
         """Evaluate test score of model and save locally."""
-- 
GitLab