diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 1c9ac32ddb294a9cf29fe942fd8eae04e9d19d44..972114244635485c72044a0856b4a28a132e9609 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -12,6 +12,7 @@ import traceback
 import copy
 from typing import Dict, Tuple, Union, List, Callable
 import ensverif
+import glob
 
 import numpy as np
 import pandas as pd
@@ -22,7 +23,8 @@ import tensorflow as tf
 from mlair.configuration import path_config
 from mlair.data_handler import Bootstraps, KerasIterator
 from mlair.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope
-from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables
+from mlair.helpers import TimeTracking, TimeTrackingWrapper, statistics, extract_value, remove_items, to_list, tables, \
+    get_sampling
 from mlair.model_modules.linear_model import OrdinaryLeastSquaredModel
 from mlair.model_modules import AbstractModelClass
 from mlair.plotting.postprocessing_plotting import PlotMonthlySummary, PlotClimatologicalSkillScore, \
@@ -120,6 +122,11 @@ class PostProcessing(RunEnvironment):
         # calculate error metrics on test data
         self.calculate_test_score()
 
+        # calculate and report ensemble scores (CRPS)
+        if self.num_realizations is not None:
+            self.report_crps("test")
+
         # sample uncertainty
         if self.data_store.get("do_uncertainty_estimate", "postprocessing"):
             self.estimate_sample_uncertainty(separate_ahead=True)
@@ -147,6 +154,57 @@ class PostProcessing(RunEnvironment):
         # plotting
         self.plot()
 
+    @TimeTrackingWrapper
+    def report_crps(self, subset):
+        """
+        Calculate the CRPS (continuous ranked probability score) for all lead times of the given subset and store
+        the results as a summary report and as a per-station report.
+
+        :param subset: name of the data subset to evaluate, e.g. "test"
+        """
+        file_pattern = os.path.join(self.forecast_path, f"forecasts_*_ens_{subset}_values.nc")
+        # get ensemble files with predictions on the original scale (skip normalized files)
+        ens_files = [f for f in glob.glob(file_pattern) if "_norm" not in f]
+
+        ds = xr.open_mfdataset(ens_files)
+        crps = {}
+        crps_per_station = {}
+        for i in range(1, self.window_lead_time + 1):
+            ens = ds["ens"].sel(
+                {self.ahead_dim: i, self.ens_moment_dim: "ens_dist_mean",
+                 self.model_type_dim: "ens"}
+            ).dropna(self.index_dim)
+            obs = ds["det"].sel(
+                {self.ahead_dim: i, self.model_type_dim: "obs"}
+            ).dropna(self.index_dim)
+            crps[f"{i}{get_sampling(self._sampling)}"] = ensverif.crps.crps(
+                ens.values.reshape(-1, self.num_realizations), obs.values.reshape(-1), distribution="emp")
+            crps_stations = {}
+            for station in ens.coords[self.iter_dim].values:
+                ens_station = ens.sel({self.iter_dim: station}).values
+                obs_station = obs.sel({self.iter_dim: station}).values
+                crps_stations[station] = ensverif.crps.crps(ens_station, obs_station, distribution="emp")
+            crps_per_station[f"{i}{get_sampling(self._sampling)}"] = crps_stations
+
+        df_tot = pd.DataFrame(crps, index=[subset])
+        df_stations = pd.DataFrame(crps_per_station)
+
+        report_path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
+        path_config.check_path_and_create(report_path)
+        self.store_crps_reports(df_tot, report_path, subset, station=False)
+        self.store_crps_reports(df_stations, report_path, subset, station=True)
+
+    @staticmethod
+    def store_crps_reports(df, report_path, subset, station=False):
+        if station is True:
+            file_name = f"crps_stations_{subset}.%s"
+        else:
+            file_name = f"crps_summary_{subset}.%s"
+        column_format = tables.create_column_format_for_tex(df)
+        tables.save_to_tex(report_path, file_name % "tex", column_format=column_format, df=df)
+        tables.save_to_md(report_path, file_name % "md", df=df)
+        df.to_csv(os.path.join(report_path, file_name % "csv"), sep=";")
+
     @TimeTrackingWrapper
     def estimate_sample_uncertainty(self, separate_ahead=False):
         """