diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 56b5c363f15aa7f40a25bd02392dd9d85bf88396..425b26cfbde6a22b2dd61b87478f00c52ccda87e 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -89,6 +89,7 @@ class PostProcessing(RunEnvironment):
         self.forecast_indicator = "nn"
         self.ahead_dim = "ahead"
         self.boot_var_dim = "boot_var"
+        self.model_type_dim = "type"
         self._run()
 
     def _run(self):
@@ -118,7 +119,7 @@ class PostProcessing(RunEnvironment):
         skill_score_competitive, skill_score_climatological, errors = self.calculate_error_metrics()
         self.skill_scores = (skill_score_competitive, skill_score_climatological)
         self.report_error_metrics(errors)
-        self.report_error_metrics(skill_score_climatological)
+        self.report_error_metrics({self.forecast_indicator: skill_score_climatological})
 
         # plotting
         self.plot()
@@ -142,7 +143,7 @@ class PostProcessing(RunEnvironment):
             except (FileNotFoundError, KeyError):
                 logging.debug(f"No competitor found for combination '{station_name}' and '{competitor_name}'.")
                 continue
-        return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None
+        return xr.concat(competing_predictions, self.model_type_dim) if len(competing_predictions) > 0 else None
 
     def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type="singleinput",
                                  bootstrap_method="shuffle") -> None:
@@ -190,7 +191,7 @@ class PostProcessing(RunEnvironment):
         # extract all requirements from data store
         forecast_path = self.data_store.get("forecast_path")
         number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
-        dims = ["index", self.ahead_dim, "type"]
+        dims = ["index", self.ahead_dim, self.model_type_dim]
         for station in self.test_data:
             X, Y = None, None
             bootstraps = BootStraps(station, number_of_bootstraps, bootstrap_type=bootstrap_type,
@@ -250,7 +251,7 @@ class PostProcessing(RunEnvironment):
                 orig = self.get_orig_prediction(forecast_path, forecast_file % str(station), number_of_bootstraps)
                 orig = orig.reshape(shape)
                 coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"])
-                orig = xr.DataArray(orig, coords=coords, dims=["index", self.ahead_dim, "type"])
+                orig = xr.DataArray(orig, coords=coords, dims=["index", self.ahead_dim, self.model_type_dim])
 
                 # calculate skill scores for each variable
                 skill = pd.DataFrame(columns=range(1, self.window_lead_time + 1))
@@ -549,7 +550,7 @@ class PostProcessing(RunEnvironment):
         with xr.open_dataarray(file) as da:
             data = da.load()
         forecast = data.sel(type=[self.forecast_indicator])
-        forecast.coords["type"] = [competitor_name]
+        forecast.coords[self.model_type_dim] = [competitor_name]
         return forecast
 
     def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray:
@@ -731,12 +732,13 @@ class PostProcessing(RunEnvironment):
         except (IndexError, KeyError, FileNotFoundError):
             return None
 
-    @staticmethod
-    def _combine_forecasts(forecast, competitor, dim="type"):
+    def _combine_forecasts(self, forecast, competitor, dim=None):
         """
         Combine forecast and competitor if both are xarray. If competitor is None, this returns forecasts and vise
         versa.
         """
+        if dim is None:
+            dim = self.model_type_dim
         try:
             return xr.concat([forecast, competitor], dim=dim)
         except (TypeError, AttributeError):
@@ -762,13 +764,17 @@ class PostProcessing(RunEnvironment):
 
             # test errors
             if external_data is not None:
-                errors[station] = statistics.calculate_error_metrics(*map(lambda x: external_data.sel(type=x),
-                                                                          [self.forecast_indicator, "obs"]),
-                                                                      dim="index")
+                for model_type in remove_items(external_data.coords[self.model_type_dim].values.tolist(), "obs"):
+                    if model_type not in errors.keys():
+                        errors[model_type] = {}
+                    errors[model_type][station] = statistics.calculate_error_metrics(
+                        *map(lambda x: external_data.sel(**{self.model_type_dim: x}),
+                             [model_type, "obs"]), dim="index")
+
             # skill score
             competitor = self.load_competitors(station)
-            combined = self._combine_forecasts(external_data, competitor, dim="type")
-            model_list = remove_items(list(combined.type.values), "obs") if combined is not None else None
+            combined = self._combine_forecasts(external_data, competitor, dim=self.model_type_dim)
+            model_list = remove_items(combined.coords[self.model_type_dim].values.tolist(), "obs") if combined is not None else None
             skill_score = statistics.SkillScores(combined, models=model_list, ahead_dim=self.ahead_dim)
             if external_data is not None:
                 skill_score_competitive[station] = skill_score.skill_scores()
@@ -778,7 +784,8 @@ class PostProcessing(RunEnvironment):
                 skill_score_climatological[station] = skill_score.climatological_skill_scores(
                     internal_data, forecast_name=self.forecast_indicator)
 
-        errors.update({"total": self.calculate_average_errors(errors)})
+        for model_type in errors.keys():
+            errors[model_type].update({"total": self.calculate_average_errors(errors[model_type])})
         return skill_score_competitive, skill_score_climatological, errors
 
     @staticmethod
@@ -796,7 +803,7 @@ class PostProcessing(RunEnvironment):
         """Create a csv file containing all results from bootstrapping."""
         report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
         path_config.check_path_and_create(report_path)
-        res = [["type", "method", "station", self.boot_var_dim, self.ahead_dim, "vals"]]
+        res = [[self.model_type_dim, "method", "station", self.boot_var_dim, self.ahead_dim, "vals"]]
         for boot_type, d0 in results.items():
             for boot_method, d1 in d0.items():
                 for station_name, vals in d1.items():
@@ -812,25 +819,26 @@ class PostProcessing(RunEnvironment):
     def report_error_metrics(self, errors):
         report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
         path_config.check_path_and_create(report_path)
-        metric_collection = {}
-        for station, station_errors in errors.items():
-            if isinstance(station_errors, xr.DataArray):
-                dim = station_errors.dims[0]
-                sel_index = [sel for sel in station_errors.coords[dim] if "CASE" in str(sel)]
-                station_errors = {str(i.values): station_errors.sel(**{dim: i}) for i in sel_index}
-            for metric, vals in station_errors.items():
-                if metric == "n":
-                    continue
-                pd_vals = pd.DataFrame.from_dict({station: vals}).T
-                pd_vals.columns = [f"{metric}(t+{x})" for x in vals.coords["ahead"].values]
-                mc = metric_collection.get(metric, pd.DataFrame())
-                mc = mc.append(pd_vals)
-                metric_collection[metric] = mc
-        for metric, error_df in metric_collection.items():
-            df = error_df.sort_index()
-            if "total" in df.index:
-                df.reindex(df.index.drop(["total"]).to_list() + ["total"], )
-            column_format = tables.create_column_format_for_tex(df)
-            file_name = f"error_report_{metric}.%s".replace(' ', '_')
-            tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df)
-            tables.save_to_md(report_path, file_name % "md", df=df)
+        for model_type in errors.keys():
+            metric_collection = {}
+            for station, station_errors in errors[model_type].items():
+                if isinstance(station_errors, xr.DataArray):
+                    dim = station_errors.dims[0]
+                    sel_index = [sel for sel in station_errors.coords[dim] if "CASE" in str(sel)]
+                    station_errors = {str(i.values): station_errors.sel(**{dim: i}) for i in sel_index}
+                for metric, vals in station_errors.items():
+                    if metric == "n":
+                        continue
+                    pd_vals = pd.DataFrame.from_dict({station: vals}).T
+                    pd_vals.columns = [f"{metric}(t+{x})" for x in vals.coords["ahead"].values]
+                    mc = metric_collection.get(metric, pd.DataFrame())
+                    mc = mc.append(pd_vals)
+                    metric_collection[metric] = mc
+            for metric, error_df in metric_collection.items():
+                df = error_df.sort_index()
+                if "total" in df.index:
+                    df.reindex(df.index.drop(["total"]).to_list() + ["total"], )
+                column_format = tables.create_column_format_for_tex(df)
+                file_name = f"error_report_{metric}_{model_type}.%s".replace(' ', '_')
+                tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df)
+                tables.save_to_md(report_path, file_name % "md", df=df)
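Note on the resulting data layout (illustration only, not part of the patch): after this change the errors dictionary built in calculate_error_metrics and consumed by report_error_metrics is nested one level deeper. The first level is keyed by model type (every entry of the "type" coordinate except "obs"), the second level by station, with an additional "total" entry per model type, and report_error_metrics writes one report per metric and model type. The sketch below uses invented station names and scalar metric values (in MLAir the per-station values are per-lead-time arrays from statistics.calculate_error_metrics), so it only shows the nesting and the file naming pattern.

# Minimal sketch of the nested error structure, with made-up values.
errors = {
    "nn":  {"DEBW107": {"mse": 4.2, "rmse": 2.0, "n": 365},
            "total":   {"mse": 4.0, "rmse": 2.0, "n": 365}},
    "ols": {"DEBW107": {"mse": 5.1, "rmse": 2.3, "n": 365},
            "total":   {"mse": 4.9, "rmse": 2.2, "n": 365}},
}

# report_error_metrics iterates over model types and would produce files such as
# error_report_mse_nn.tex and error_report_mse_ols.md, skipping the sample-size key "n".
for model_type, station_errors in errors.items():
    for station, metrics in station_errors.items():
        for metric, value in metrics.items():
            if metric == "n":
                continue
            print(f"{model_type} / {station}: {metric} = {value}")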