diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 0ff4afb37e76c0c03127bd84258ccb5fd2eb4a90..634d887c279c4bad8f8993542bf9534ab7b1dd52 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -87,6 +87,7 @@ class PostProcessing(RunEnvironment):
         self.competitor_path = self.data_store.get("competitor_path")
         self.competitors = to_list(self.data_store.get_default("competitors", default=[]))
         self.forecast_indicator = "nn"
+        self.observation_indicator = "obs"
         self.ahead_dim = "ahead"
         self.boot_var_dim = "boot_var"
         self.model_type_dim = "type"
@@ -215,7 +216,7 @@ class PostProcessing(RunEnvironment):
                 # store also true labels for each station
                 labels = np.expand_dims(Y, axis=-1)
                 file_name = os.path.join(forecast_path, f"bootstraps_{station}_{bootstrap_method}_labels.nc")
-                labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=dims)
+                labels = xr.DataArray(labels, coords=(*coords, [self.observation_indicator]), dims=dims)
                 labels.to_netcdf(file_name)
 
     def calculate_bootstrap_skill_scores(self, bootstrap_type, bootstrap_method) -> Dict[str, xr.DataArray]:
@@ -517,7 +518,7 @@ class PostProcessing(RunEnvironment):
                 full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
                 prediction_dict = {self.forecast_indicator: nn_prediction,
                                    "persi": persistence_prediction,
-                                   "obs": observation,
+                                   self.observation_indicator: observation,
                                    "ols": ols_prediction}
                 all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]),
                                                               time_dimension, ahead_dim=self.ahead_dim,
@@ -764,17 +765,22 @@ class PostProcessing(RunEnvironment):
 
             # test errors
            if external_data is not None:
-                for model_type in remove_items(external_data.coords[self.model_type_dim].values.tolist(), "obs"):
+                model_type_list = external_data.coords[self.model_type_dim].values.tolist()
+                for model_type in remove_items(model_type_list, self.observation_indicator):
                     if model_type not in errors.keys():
                         errors[model_type] = {}
                     errors[model_type][station] = statistics.calculate_error_metrics(
                         *map(lambda x: external_data.sel(**{self.model_type_dim: x}),
-                             [model_type, "obs"]), dim="index")
+                             [model_type, self.observation_indicator]), dim="index")
 
             # load competitors
             competitor = self.load_competitors(station)
             combined = self._combine_forecasts(external_data, competitor, dim=self.model_type_dim)
-            model_list = remove_items(combined.coords[self.model_type_dim].values.tolist(), "obs") if combined is not None else None
+            if combined is not None:
+                model_list = remove_items(combined.coords[self.model_type_dim].values.tolist(),
+                                          self.observation_indicator)
+            else:
+                model_list = None
 
             # test errors of competitors
             for model_type in remove_items(model_list, errors.keys()):
@@ -782,7 +788,7 @@ class PostProcessing(RunEnvironment):
                 errors[model_type] = {}
                 errors[model_type][station] = statistics.calculate_error_metrics(
                     *map(lambda x: combined.sel(**{self.model_type_dim: x}),
-                         [model_type, "obs"]), dim="index")
+                         [model_type, self.observation_indicator]), dim="index")
 
             # skill score
             skill_score = statistics.SkillScores(combined, models=model_list, ahead_dim=self.ahead_dim)
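
Note (not part of the patch): the change introduces self.observation_indicator and uses it wherever the observation column of a forecast array was previously addressed with the hard-coded "obs" literal. Below is a minimal, hypothetical sketch of that selection pattern; the forecast array, its size, and the dimension name are made up for illustration and only mirror the attributes defined in PostProcessing.__init__.

    import numpy as np
    import xarray as xr

    # Stand-ins for the PostProcessing attributes (assumed values from the diff).
    forecast_indicator = "nn"      # self.forecast_indicator
    observation_indicator = "obs"  # self.observation_indicator (added by this patch)
    model_type_dim = "type"        # self.model_type_dim

    # Toy forecast array: one column per model type plus one for the observations.
    forecast = xr.DataArray(
        np.random.rand(10, 2),
        dims=("index", model_type_dim),
        coords={model_type_dim: [forecast_indicator, observation_indicator]},
    )

    # The pattern centralised by the refactor: select a model column and the
    # observation column via the shared attribute instead of a repeated "obs" literal.
    prediction = forecast.sel(**{model_type_dim: forecast_indicator})
    observation = forecast.sel(**{model_type_dim: observation_indicator})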