def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, factorization="calibration_refinement",
                             quantiles=(0.05, 0.5, 0.95)):
    """
    Calculates conditional quantiles of forecast and reference data for the chosen factorization.

    The conditioning variable is split into unit-width bins spanning floor(min) to ceil(max) of its values.
    For every bin, the requested quantiles of the target variable are computed over all samples whose
    conditioning value falls into that bin. Bins are left-closed/right-open, except for the last bin which
    also includes its upper edge so that the maximum value is not dropped.

    :param data_fcst: forecast data array
    :param data_ref: reference data array (same dimensions, shape and coordinates as data_fcst)
    :param factorization: "calibration_refinement" (default) conditions data_ref on data_fcst,
                          "likelihood-base_rate" conditions data_fcst on data_ref
    :param quantiles: list/tuple of at least three quantile levels
    :return: data array with dimensions (bin_center, quantile) holding the conditional quantiles;
             bins without any sample are filled with NaN
    :raises ValueError: if the inputs are no DataArrays, if their dimensions/coordinates differ,
                        if less than three quantiles are passed or if factorization is unknown
    """
    method = calculate_cond_quantiles.__name__

    # sanity checks
    if not isinstance(data_fcst, xr.DataArray):
        raise ValueError("%{0}: data_fcst must be a DataArray.".format(method))

    if not isinstance(data_ref, xr.DataArray):
        raise ValueError("%{0}: data_ref must be a DataArray.".format(method))

    # Compare coordinate by coordinate with .equals(): applying '==' to the coords mappings compares the
    # contained arrays element-wise, which has no well-defined truth value for more than one element.
    same_layout = (data_fcst.dims == data_ref.dims
                   and data_fcst.shape == data_ref.shape
                   and set(data_fcst.coords) == set(data_ref.coords)
                   and all(data_fcst.coords[name].equals(data_ref.coords[name]) for name in data_fcst.coords))
    if not same_layout:
        raise ValueError("%{0}: Coordinates and dimensions of data_fcst and data_ref must be the same".format(method))

    # A list is required below: xarray interprets a tuple passed as coordinate value as (dims, data[, attrs]).
    quantile_levels = list(quantiles)
    nquantiles = len(quantile_levels)
    if not nquantiles >= 3:
        raise ValueError("%{0}: quantiles must be a list/tuple of at least three float values ([0..1])".format(method))

    if factorization == "calibration_refinement":
        data_cond = data_fcst
        data_tar = data_ref
    elif factorization == "likelihood-base_rate":
        data_cond = data_ref
        data_tar = data_fcst
    else:
        raise ValueError("%{0}: Choose either 'calibration_refinement' or 'likelihood-base_rate' for factorization"
                         .format(method))

    # get bins for conditioning
    data_cond_min, data_cond_max = np.floor(np.min(data_cond)), np.ceil(np.max(data_cond))
    lower_edge, upper_edge = int(data_cond_min), int(data_cond_max)
    if upper_edge == lower_edge:
        # constant, integer-valued conditioning data: keep one bin instead of none
        upper_edge = lower_edge + 1
    bins = np.arange(lower_edge, upper_edge + 1)
    bins_c = 0.5 * (bins[:-1] + bins[1:])
    nbins = len(bins_c)

    # initialize quantile data array
    quantile_panel = xr.DataArray(np.full((nbins, nquantiles), np.nan),
                                  coords={"bin_center": bins_c, "quantile": quantile_levels},
                                  dims=["bin_center", "quantile"])

    # fill the quantile data array
    for i, bin_center in enumerate(bins_c):
        in_bin = data_cond >= bins[i]
        if i == nbins - 1:
            # last bin is closed on the right to retain samples equal to the maximum
            in_bin = in_bin & (data_cond <= bins[i + 1])
        else:
            in_bin = in_bin & (data_cond < bins[i + 1])
        # conditioning of the target data on the current bin
        data_cropped = data_tar.where(in_bin)
        # quantile-calculation over all remaining (non-masked) samples
        quantile_panel.loc[dict(bin_center=bin_center)] = data_cropped.quantile(quantile_levels)

    return quantile_panel