diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 972114244635485c72044a0856b4a28a132e9609..59abe4d6c20bfaca8aadc4b972ba10b6934657c4 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -156,7 +156,7 @@ class PostProcessing(RunEnvironment): @TimeTrackingWrapper - def report_crps(self, subset): + def report_crps2(self, subset): """ Calculate CRPS for all lead times :return: @@ -194,6 +194,74 @@ class PostProcessing(RunEnvironment): self.store_crps_reports(df_tot, report_path, subset, station=False) self.store_crps_reports(df_stations, report_path, subset, station=True) + def report_crps(self, subset): + """ + Calculate CRPS for all lead times + :return: + :rtype: + """ + file_pattern = os.path.join(self.forecast_path, f"forecasts_*_ens_{subset}_values.nc") + # get ens files with predictions (not normalized) + # ens_files = [e for e in filter(lambda x: not "_norm" in x, glob.glob(file_pattern))] + + # ds = xr.open_mfdataset(ens_files) + collector_ens = [] + collector_obs = [] + + + + + # crps[f"{i}{get_sampling(self._sampling)}"] = ensverif.crps.crps( + # ens.values.reshape(-1, self.num_realizations), obs.values.reshape(-1), distribution="emp") + crps_stations = {} + idx_counter = 0 + for station in self.test_data.keys(): + station_based_file_name = os.path.join(self.forecast_path, f"forecasts_{station}_ens_{subset}_values.nc") + ds = xr.open_mfdataset(station_based_file_name) + ens_station = ds["ens"].sel({self.iter_dim: station, self.ens_moment_dim: "ens_dist_mean", + self.model_type_dim: "ens"}).dropna(self.index_dim) + obs_station = ds["det"].sel({self.model_type_dim: "obs"}).dropna(self.index_dim) + + new_index = range(idx_counter, idx_counter+len(ens_station[self.index_dim])) + ens_reindex = xr.DataArray(data=ens_station.data, + dims=[self.index_dim, self.ens_realization_dim, self.ahead_dim], + coords={self.index_dim: new_index, + self.ens_realization_dim: ens_station.coords[self.ens_realization_dim], + self.ahead_dim: ens_station.coords[self.ahead_dim]}) + obs_reindex = xr.DataArray(data=obs_station.data, + dims=[self.index_dim, self.ahead_dim], + coords={self.index_dim: new_index, + self.ahead_dim: ens_station.coords[self.ahead_dim]}) + collector_ens.append(ens_reindex) + collector_obs.append(obs_reindex) + idx_counter = new_index[-1] + + crps_times = self._calc_crps_for_lead_times(ens_station, obs_station) + + crps_stations[station] = crps_times + + + df_stations = pd.DataFrame(crps_stations) + report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report") + path_config.check_path_and_create(report_path) + self.store_crps_reports(df_stations, report_path, subset, station=True) + + full_ens = xr.concat(collector_ens, dim=self.index_dim) + full_obs = xr.concat(collector_obs, dim=self.index_dim) + df_tot = pd.DataFrame(self._calc_crps_for_lead_times(full_ens, full_obs), index=to_list(subset)) + self.store_crps_reports(df_tot, report_path, subset, station=False) + + + def _calc_crps_for_lead_times(self, ens, obs): + crps_collector = {} + for i in range(1, self.window_lead_time + 1): + ens_res = ens.sel({self.ahead_dim: i, }) + obs_res = obs.sel({self.ahead_dim: i, }) + crps_collector[f"{i}{get_sampling(self._sampling)}"] = ensverif.crps.crps(ens_res, obs_res, + distribution="emp") + + return crps_collector + @staticmethod def store_crps_reports(df, report_path, subset, station=False): if station is True: @@ -1092,7 +1160,7 @@ class PostProcessing(RunEnvironment): keys = list(kwargs.keys()) res_coords = [index.index, ahead_names, keys] res_dims = [index_dim, ahead_dim, type_dim] - res_fill_shape = (len(index.index), len(ahead_names), len(keys)) + # res_fill_shape = (len(index.index), len(ahead_names), len(keys)) if (ens_coords is not None) and (ens_dims is not None): ens_coords = to_list(ens_coords) ens_dims = to_list(ens_dims)