From a6ce89e1abca3422dc525161e3dcbba1f27d200c Mon Sep 17 00:00:00 2001
From: gong1 <b.gong@fz-juelich.de>
Date: Wed, 9 Jun 2021 12:33:13 +0200
Subject: [PATCH] update unitest for quantile plot

---
 test/test_visualize_postprocess.py | 31 ++++++++++++++++++++++++++++--
 1 file changed, 29 insertions(+), 2 deletions(-)

diff --git a/test/test_visualize_postprocess.py b/test/test_visualize_postprocess.py
index d9a3cdc0..3d0408ba 100644
--- a/test/test_visualize_postprocess.py
+++ b/test/test_visualize_postprocess.py
@@ -127,9 +127,36 @@ def test_run_deterministic(vis_case1):
     assert eval_metric_ds[metric_name][0,5] == sample_gen_ref_mse_t5   
 
 
+def test_plot_conditional_quantiles(vis_case1):
+    vis_case1.nun_samples_per_epoch = 20
+    vis_case1.run_deterministic()
+    # the variables for conditional quantile plot
+    var_fcst = vis_case1.cond_quantile_vars[0]
+    var_ref = vis_case1.cond_quantile_vars[1]
+    data_fcst = get_era5_varatts(vis_case1.cond_quantiple_ds[var_fcst], vis_case1.cond_quantiple_ds[var_fcst].name)
+    data_ref = get_era5_varatts(vis_case1.cond_quantiple_ds[var_ref], vis_case1.cond_quantiple_ds[var_ref].name)
+    print("data_fcast",data_fcst)
+    fhhs = data_fcst["fcst_hour"]
+    
+    hh = 1 
+    quantile_panel_cf, cond_variable_cf = calculate_cond_quantiles(data_fcst.sel(fcst_hour=hh),
+                                                                           data_ref.sel(fcst_hour=hh),
+                                                                           factorization="calibration_refinement",
+                                                                           quantiles=(0.05, 0.5, 0.95))
 
-       
-
+    
+   
+   
+   data_cond = data_fcst.sel(fcst_hour=hh)  
+   data_tar = data_ref.sel(fcst_hour=hh)
+   data_cond_min, data_cond_max = np.floor(np.min(data_cond)), np.ceil(np.max(data_cond))
+   bins = list(np.arange(int(data_cond_min), int(data_cond_max) + 1))
+   nbins = len(bins) - 1
+   
+   bin_l_1, bin_r_1 = bins[0], bins[1]
+   #find position of the values between bin
+   data_cropped = data_tar.where(np.logical_and(data_cond >= bins_l_1, data_cond < bins_r_l))
+   
 
     
        
-- 
GitLab