diff --git a/video_prediction_tools/postprocess/statistical_evaluation.py b/video_prediction_tools/postprocess/statistical_evaluation.py index fcb4d8e93a5ad6fe99c0210632e5f7e38df4f2ce..6756c82c98fefdc2ed01d4abea7d21bd5a53ddcc 100644 --- a/video_prediction_tools/postprocess/statistical_evaluation.py +++ b/video_prediction_tools/postprocess/statistical_evaluation.py @@ -15,6 +15,7 @@ try: l_tqdm = True except: l_tqdm = False +from general_utils import provide_default # basic data types da_or_ds = Union[xr.DataArray, xr.Dataset] @@ -56,6 +57,13 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa raise ValueError("%{0}: Choose either 'calibration_refinement' or 'likelihood-base_rate' for factorization" .format(method)) + # get and set some basic attributes + data_cond_longname = provide_default(data_cond.attr, "longname", "conditioning_variable") + data_cond_unit = provide_default(data_cond.attr, "unit", "unknown") + + data_tar_longname = provide_default(data_tar.attr, "longname", "target_variable") + data_tar_unit = provide_default(data_cond.attr, "unit", "unknown") + # get bins for conditioning data_cond_min, data_cond_max = np.floor(np.min(data_cond)), np.ceil(np.max(data_cond)) bins = list(np.arange(int(data_cond_min), int(data_cond_max) + 1)) @@ -63,7 +71,9 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa nbins = len(bins) - 1 # initialize quantile data array quantile_panel = xr.DataArray(np.full((nbins, nquantiles), np.nan), - coords={"bin_center": bins_c, "quantile": quantiles}, dims=["bin_center", "quantile"]) + coords={"bin_center": bins_c, "quantile": quantiles}, dims=["bin_center", "quantile"], + attrs={"cond_var_name": data_cond_longname, "cond_var_unit": data_cond_unit, + "tar_var_name": data_tar_longname, "tar_var_unit": data_tar_unit}) # fill the quantile data array for i in np.arange(nbins): # conditioning of ground truth based on forecast @@ -73,6 +83,7 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa return quantile_panel, data_cond + def avg_metrics(metric: da_or_ds, dim_name: str): """ Averages metric over given dimension