diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 22e55220c6aa54f8352435cc5f5ddaf4f072f0b7..da671d98ef07869ccf1e584a349f78129f2e5784 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -1397,23 +1397,45 @@ class PlotTimeEvolutionMetric(AbstractPlotClass): @TimeTrackingWrapper class PlotSeasonalMSEStack(AbstractPlotClass): - def __init__(self, data_path: str, plot_folder: str = ".", boot_dim="boots", ahead_dim="ahead", - sampling: str = "daily", error_measure: str = "MSE", error_unit: str = "ppb$^2$", - model_name: str = "NN", model_indicator: str = "nn", model_type_dim: str = "type"): + def __init__(self, data, data_path: str, plot_folder: str = ".", boot_dim="boots", ahead_dim="ahead", + sampling: str = "daily", error_measure: str = "MSE", error_unit: str = "ppb$^2$", time_dim="index", + model_type_dim: str = "type", model_name: str = "NN", model_indicator: str = "nn",): """Set attributes and create plot.""" super().__init__(plot_folder, "seasonal_mse_stack_plot") - self.plot_name_orig = "seasonal_mse_stack_plot" self._data_path = data_path self.season_dim = "season" + self.time_dim = time_dim + self.ahead_dim = ahead_dim self.error_unit = error_unit self.error_measure = error_measure - self._data = self._prepare_data(boot_dim, data_path, model_type_dim, model_indicator, model_name) + self.dim_order = [self.season_dim, ahead_dim, model_type_dim] + + # mse from monthly blocks + self.plot_name_orig = "seasonal_mse_stack_plot" + self._data = self._prepare_data(data) for orientation in ["horizontal", "vertical"]: for split_ahead in [True, False]: self._plot(ahead_dim, split_ahead, sampling, orientation) self._save(bbox_inches="tight") - def _prepare_data(self, boot_dim, data_path, model_type_dim, model_indicator, model_name): + # mes from resampling + self.plot_name_orig = "seasonal_mse_from_uncertainty_stack_plot" + self._data = self._prepare_data_from_uncertainty(boot_dim, data_path, model_type_dim, model_indicator, + model_name) + for orientation in ["horizontal", "vertical"]: + for split_ahead in [True, False]: + self._plot(ahead_dim, split_ahead, sampling, orientation) + self._save(bbox_inches="tight") + + def _prepare_data(self, data): + season_mean = data.groupby(f"{self.time_dim}.{self.season_dim}").mean() + total_mean = data.mean(self.time_dim) + factor = season_mean / season_mean.sum(self.season_dim) + season_share = (total_mean * factor).reindex({self.season_dim: ["DJF", "MAM", "JJA", "SON"]}) + season_share = season_share.mean(set(season_share.dims).difference(self.dim_order)) + return season_share.sortby(season_share.sum([self.season_dim, self.ahead_dim])).transpose(*self.dim_order) + + def _prepare_data_from_uncertainty(self, boot_dim, data_path, model_type_dim, model_indicator, model_name): season_dim = self.season_dim data = {} for season in ["total", "DJF", "MAM", "JJA", "SON"]: @@ -1427,11 +1449,11 @@ class PlotSeasonalMSEStack(AbstractPlotClass): for season in data.keys(): mean[season] = data[season].mean(boot_dim) xr_data = xr.Dataset(mean).to_array(season_dim) - factor = xr_data.sel({season_dim: "total"}) / xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}).sum( - season_dim) - xr_data = xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}) * factor xr_data[model_type_dim] = [v if v != model_indicator else model_name for v in xr_data[model_type_dim].values] - return xr_data + xr_season = xr_data.sel({season_dim: ["DJF", "MAM", "JJA", "SON"]}) + factor = xr_season / xr_season.sum(season_dim) + season_share = xr_data.sel({season_dim: "total"}) * factor + return season_share.sortby(season_share.sum([self.season_dim, self.ahead_dim])).transpose(*self.dim_order) @staticmethod def _get_target_sampling(sampling, pos): @@ -1453,7 +1475,7 @@ class PlotSeasonalMSEStack(AbstractPlotClass): def _plot(self, dim, split_ahead=True, sampling="daily", orientation="vertical"): _, sampling_letter = self._get_target_sampling(sampling, 1) - if split_ahead is True: + if split_ahead is False: self.plot_name = self.plot_name_orig + "_total_" + orientation data = self._data.mean(dim) if orientation == "vertical": diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 459a1928958a96238af89caa4241911554df416f..a6f1423fbc6eb3869e4d628f7a8f787dbd1d2cc4 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -91,6 +91,7 @@ class PostProcessing(RunEnvironment): self.uncertainty_estimate = None self.uncertainty_estimate_seasons = {} self.block_mse_per_station = None + self.block_mse = None self.competitor_path = self.data_store.get("competitor_path") self.competitors = to_list(self.data_store.get_default("competitors", default=[])) self.forecast_indicator = "nn" @@ -117,6 +118,10 @@ class PostProcessing(RunEnvironment): # calculate error metrics on test data self.calculate_test_score() + # calculate monthly block mse + self.block_mse, self.block_mse_per_station = self.calculate_block_mse(evaluate_competitors=True, + separate_ahead=True, block_length="1m") + # sample uncertainty if self.data_store.get("do_uncertainty_estimate", "postprocessing"): self.estimate_sample_uncertainty(separate_ahead=True) @@ -156,10 +161,12 @@ class PostProcessing(RunEnvironment): block_length = self.data_store.get_default("block_length", default="1m", scope="uncertainty_estimate") evaluate_competitors = self.data_store.get_default("evaluate_competitors", default=True, scope="uncertainty_estimate") - block_mse, block_mse_per_station = self.calculate_block_mse(evaluate_competitors=evaluate_competitors, - separate_ahead=separate_ahead, - block_length=block_length) - self.block_mse_per_station = block_mse_per_station + if evaluate_competitors is True and separate_ahead is True and block_length == "1m": + block_mse, block_mse_per_station = self.block_mse, self.block_mse_per_station + else: + block_mse, block_mse_per_station = self.calculate_block_mse(evaluate_competitors=evaluate_competitors, + separate_ahead=separate_ahead, + block_length=block_length) estimate = statistics.create_n_bootstrap_realizations( block_mse, dim_name_time=self.index_dim, dim_name_model=self.model_type_dim, dim_name_boots=self.uncertainty_estimate_boot_dim, n_boots=n_boots, seasons=["DJF", "MAM", "JJA", "SON"]) @@ -692,7 +699,7 @@ class PostProcessing(RunEnvironment): try: if "PlotSeasonalMSEStack" in plot_list: report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") - PlotSeasonalMSEStack(data_path=report_path, plot_folder=self.plot_path, + PlotSeasonalMSEStack(data=self.block_mse_per_station, data_path=report_path, plot_folder=self.plot_path, boot_dim=self.uncertainty_estimate_boot_dim, ahead_dim=self.ahead_dim, sampling=self._sampling, error_measure="Mean Squared Error", error_unit=r"ppb$^2$", model_indicator=self.forecast_indicator, model_name=self.model_display_name,