diff --git a/test/test_visualize_postprocess.py b/test/test_visualize_postprocess.py index d9a3cdc01bd2f803758d090928fb2b752cae8890..3d0408bab293864af689bc4705f7c8cfa5506403 100644 --- a/test/test_visualize_postprocess.py +++ b/test/test_visualize_postprocess.py @@ -127,9 +127,36 @@ def test_run_deterministic(vis_case1): assert eval_metric_ds[metric_name][0,5] == sample_gen_ref_mse_t5 +def test_plot_conditional_quantiles(vis_case1): + vis_case1.nun_samples_per_epoch = 20 + vis_case1.run_deterministic() + # the variables for conditional quantile plot + var_fcst = vis_case1.cond_quantile_vars[0] + var_ref = vis_case1.cond_quantile_vars[1] + data_fcst = get_era5_varatts(vis_case1.cond_quantiple_ds[var_fcst], vis_case1.cond_quantiple_ds[var_fcst].name) + data_ref = get_era5_varatts(vis_case1.cond_quantiple_ds[var_ref], vis_case1.cond_quantiple_ds[var_ref].name) + print("data_fcast",data_fcst) + fhhs = data_fcst["fcst_hour"] + + hh = 1 + quantile_panel_cf, cond_variable_cf = calculate_cond_quantiles(data_fcst.sel(fcst_hour=hh), + data_ref.sel(fcst_hour=hh), + factorization="calibration_refinement", + quantiles=(0.05, 0.5, 0.95)) - - + + + + data_cond = data_fcst.sel(fcst_hour=hh) + data_tar = data_ref.sel(fcst_hour=hh) + data_cond_min, data_cond_max = np.floor(np.min(data_cond)), np.ceil(np.max(data_cond)) + bins = list(np.arange(int(data_cond_min), int(data_cond_max) + 1)) + nbins = len(bins) - 1 + + bin_l_1, bin_r_1 = bins[0], bins[1] + #find position of the values between bin + data_cropped = data_tar.where(np.logical_and(data_cond >= bins_l_1, data_cond < bins_r_l)) +