diff --git a/test/run_pytest.sh b/test/run_pytest.sh
index 83220d34a51379e93add931ae6e03e9491b5bce4..6aae33cf0312455efbaa95bbd491440e6d672b2e 100644
--- a/test/run_pytest.sh
+++ b/test/run_pytest.sh
@@ -2,7 +2,7 @@
 
 # Name of virtual environment 
 #VIRT_ENV_NAME="vp_new_structure"
-VIRT_ENV_NAME="juwels_env"
+VIRT_ENV_NAME="env_hdfml"
 
 if [ -z ${VIRTUAL_ENV} ]; then
    if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then
@@ -24,7 +24,7 @@ fi
 source ../video_prediction_tools/env_setup/modules_train.sh
 ##Test for preprocess moving mnist
 #python -m pytest test_prepare_moving_mnist_data.py
-python -m pytest test_train_moving_mnist_data.py 
+#python -m pytest test_train_moving_mnist_data.py 
 #Test for process step2
 #python -m pytest test_data_preprocess_step2.py
 #python -m pytest test_era5_data.py
@@ -33,5 +33,5 @@ python -m pytest test_train_moving_mnist_data.py
 #rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/* 
 #python -m pytest test_train_model_era5.py
 #python -m pytest test_vanilla_vae_model.py
-#python -m pytest test_visualize_postprocess.py
+python -m pytest test_visualize_postprocess.py
 #python -m pytest test_meta_postprocess.py
diff --git a/test/test_visualize_postprocess.py b/test/test_visualize_postprocess.py
index aebcda3754dd3599c1da1c7c5b0cec1df8364b94..3d0408bab293864af689bc4705f7c8cfa5506403 100644
--- a/test/test_visualize_postprocess.py
+++ b/test/test_visualize_postprocess.py
@@ -7,17 +7,17 @@ from main_scripts.main_visualize_postprocess import *
 import pytest
 import numpy as np
 import datetime
-
+from netCDF4 import Dataset, date2num
 
 ########Test case 1################
 results_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/savp/20210324T120926_ji4_savp_cv12" 
 checkpoint = "/p/project/deepacf/deeprain/video_prediction_shared_folder/models/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/savp/20210324T120926_ji4_savp_cv12" 
 mode = "test"
 batch_size = 2
-num_samples = 16
 num_stochastic_samples = 2
 gpu_mem_frac = 0.5
-seed=12345
+seed = 12345
+eval_metrics=["mse", "psnr"]
 
 
 class MyClass:
@@ -31,28 +31,17 @@ args = MyClass(results_dir)
 @pytest.fixture(scope="module")
 def vis_case1():
     return Postprocess(results_dir=results_dir,checkpoint=checkpoint,
-                       mode=mode,batch_size=batch_size,num_samples=num_samples,num_stochastic_samples=num_stochastic_samples,
-                       gpu_mem_frac=gpu_mem_frac,seed=seed,args=args)
-######instance2
-num_samples2 = 200000
-@pytest.fixture(scope="module")
-def vis_case2():
-    return Postprocess(results_dir=results_dir,checkpoint=checkpoint,
-                       mode=mode,batch_size=batch_size,num_samples=num_samples2,num_stochastic_samples=num_stochastic_samples,
-                       gpu_mem_frac=gpu_mem_frac,seed=seed,args=args)
+                       mode=mode,batch_size=batch_size, 
+                       num_stochastic_samples=num_stochastic_samples,
+                       seed=seed,args=args,eval_metrics=eval_metrics)
 
 def test_load_jsons(vis_case1):
-    vis_case1.set_seed()
-    vis_case1.save_args_to_option_json()
-    vis_case1.copy_data_model_json()
-    vis_case1.load_jsons()
     assert vis_case1.dataset == "era5"
     assert vis_case1.model == "savp"
     assert vis_case1.input_dir_tfr == "/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/tfrecords_seq_len_24"
     assert vis_case1.run_mode == "deterministic"
 
 def test_get_metadata(vis_case1):
-    vis_case1.get_metadata()
     assert vis_case1.height == 56
     assert vis_case1.width == 92
     assert vis_case1.vars_in[0] == "2t"
@@ -60,70 +49,123 @@ def test_get_metadata(vis_case1):
 
 
 def test_setup_test_dataset(vis_case1):
-    vis_case1.setup_test_dataset()
-    vis_case1.test_dataset.mode == mode
+    assert vis_case1.test_dataset.mode == mode  # the bare comparison was discarded, so the test could never fail
-  
-#def test_copy_data_model_json(vis_case1):
-#    vis_case1.copy_data_model_json()
-#    isfile_copy = os.path.isfile(os.path.join(checkpoint,"options.json"))
-#    assert isfile_copy == True
-#    isfile_copy_model_hpamas = os.path.isfile(os.path.join(checkpoint,"model_hparams.json"))
-#    assert isfile_copy_model_hpamas == True
-
-
-def test_setup_num_samples_per_epoch(vis_case1):
-    vis_case1.setup_test_dataset()
-    vis_case1.setup_num_samples_per_epoch()  
-    assert vis_case1.num_samples_per_epoch == num_samples
-    
+
 def test_get_data_params(vis_case1):
-    vis_case1.get_data_params()
     assert vis_case1.context_frames == 12
     assert vis_case1.future_length == 12
 
 def test_run_deterministic(vis_case1):
-    vis_case1()
+    vis_case1.num_samples_per_epoch = 20
     vis_case1.init_session()
     vis_case1.restore(vis_case1.sess,vis_case1.checkpoint)
-    vis_case1.sample_ind = 0
-    vis_case1.init_eval_metrics_list()
-    vis_case1.input_results,vis_case1.input_images_denorm_all, vis_case1.t_starts = vis_case1.run_and_plot_inputs_per_batch() 
-    assert len(vis_case1.t_starts_results) == batch_size
-    ts_1 = vis_case1.t_starts[0][0]
+    print("fcast-product",vis_case1.fcst_products)
+    eval_metric_ds = Postprocess.init_metric_ds(vis_case1.fcst_products, vis_case1.eval_metrics, vis_case1.vars_in[vis_case1.channel], vis_case1.num_samples_per_epoch, vis_case1.future_length)
+
+    input_results,input_images_denorm_all,t_starts = vis_case1.get_input_data_per_batch(vis_case1.inputs) 
+    assert len(t_starts) == batch_size
+    ts_1 = t_starts[0][0]
     year = str(ts_1)[:4]
     month = str(ts_1)[4:6]
     filename = "ecmwf_era5_" +  str(ts_1)[2:] + ".nc"
-    fl = os.path.join("/p/project/deepacf/deeprain/video_prediction_shared_folder/extractedData",year,month,filename)
+    fl = os.path.join("/p/scratch/deepacf/deeprain/ambs_era5/extractedData",year, month, filename)
     print("netCDF file name:",fl)
     with Dataset(fl,"r")  as data_file:
        t2_var = data_file.variables["2t"][0,:,:]
     t2_var = np.array(t2_var)    
     t2_max = np.max(t2_var[117:173,0:92])
     t2_min = np.min(t2_var[117:173,0:92])
-    input_image = np.array(vis_case1.input_images_denorm_all)[0,0,:,:,0] #get the first batch id and 1st sequence image
+    input_image = np.array(input_images_denorm_all)[0,0,:,:,0] #get the first batch id and 1st sequence image
     input_img_max = np.max(input_image)
     input_img_min = np.min(input_image)
     print("input_image",input_image[0,:10])
     assert t2_max == input_img_max
-    assert t2_min ==  input_img_min
-   
-    feed_dict = {input_ph: vis_case1.input_results[name] for name, input_ph in vis_case1.inputs.items()}
+    assert t2_min == input_img_min
+    sample_ind = 0 
+    feed_dict = {input_ph: input_results[name] for name, input_ph in vis_case1.inputs.items()}
     gen_images = vis_case1.sess.run(vis_case1.video_model.outputs['gen_images'], feed_dict=feed_dict)
+    gen_images_denorm = vis_case1.denorm_images_all_channels(gen_images, vis_case1.vars_in, vis_case1.norm_cls,
+                                                                norm_method="minmax")
     ############Test persistenct value#############
-    vis_case1.ts = Postprocess.generate_seq_timestamps(vis_case1.t_starts[0], len_seq=vis_case1.sequence_length)
-    vis_case1.get_and_plot_persistent_per_sample(sample_id=0)
-    ts_1_per = (datetime.datetime.strptime(str(ts_1), '%Y%m%d%H') - datetime.timedelta(hours=23)).strftime("%Y%m%d%H")
+    times_0, init_times = vis_case1.get_init_time(t_starts)
+    batch_ds = vis_case1.create_dataset(input_images_denorm_all, gen_images_denorm, init_times)
+    nbs = np.minimum(vis_case1.batch_size, vis_case1.num_samples_per_epoch - sample_ind)
+  
+    times_seq = (pd.date_range(times_0[0], periods=int(vis_case1.sequence_length), freq="h")).to_pydatetime() 
+    persistence_seq, _ = Postprocess.get_persistence(times_seq, vis_case1.input_dir_pkl)
+    ts_1_per = (pd.to_datetime(times_0[0]) -  datetime.timedelta(hours=23)).strftime("%Y%m%d%H")
+    
     year_per = str(ts_1_per)[:4]
     month_per = str(ts_1_per)[4:6]
     filename_per = "ecmwf_era5_" +  str(ts_1_per)[2:] + ".nc"
-    fl_per = os.path.join("/p/project/deepacf/deeprain/video_prediction_shared_folder/extractedData",year_per,month_per,filename_per)
+ 
+    fl_per = os.path.join("/p/scratch/deepacf/deeprain/ambs_era5/extractedData",year_per,month_per,filename_per)
     with Dataset(fl_per,"r")  as data_file:
-       t2_var_per = data_file.variables["2t"][0,117:173,0:92]    
+        t2_var_per = data_file.variables["2t"][0,117:173,0:92]    
      
     t2_per_var = np.array(t2_var_per)
     t2_per_max = np.max(t2_per_var)
-    per_image_max = np.max(vis_case1.persistence_images[0])
+    per_image_max = np.max(persistence_seq[0])
     assert t2_per_max == per_image_max
+    
+
+    ##Test evaluation metric
+    for ivar, var in enumerate(vis_case1.vars_in):
+        batch_ds["{0}_persistence_fcst".format(var)].loc[dict(init_time=init_times[0])] = \
+                        persistence_seq[vis_case1.context_frames-1:, :, :, ivar]
+        
+    eval_metric_ds = vis_case1.populate_eval_metric_ds(eval_metric_ds,batch_ds,sample_ind,vis_case1.vars_in[vis_case1.channel])
+    ## now manually calculate the mse and check that the values match the ones in eval_metric_ds
+    # calculate the mse between generated images and reference images
+    sample_gen = gen_images_denorm[0,vis_case1.context_frames-1:,:,:,vis_case1.channel]  
+    sample_ref = input_images_denorm_all[0,vis_case1.context_frames:,:,:,vis_case1.channel]
+    sample_gen_ref_mse_t0 = np.mean((sample_gen[0] - sample_ref[0])**2)
+    metric_name = "2t_savp_mse"
+    print("eval_metric_ds",eval_metric_ds)
+    assert eval_metric_ds[metric_name][0,0] == sample_gen_ref_mse_t0
+    sample_gen_ref_mse_t5 = np.mean((sample_gen[5] - sample_ref[5])**2)
+    assert eval_metric_ds[metric_name][0,5] == sample_gen_ref_mse_t5   
+
+
+def test_plot_conditional_quantiles(vis_case1):
+    """
+    Runs the deterministic postprocessing and the conditional quantile calculation for one lead time.
+    NOTE(review): no assertion is made yet; this only checks that the steps run without raising.
+    """
+    vis_case1.num_samples_per_epoch = 20
+    vis_case1.run_deterministic()
+    # the variables for conditional quantile plot
+    var_fcst = vis_case1.cond_quantile_vars[0]
+    var_ref = vis_case1.cond_quantile_vars[1]
+    data_fcst = get_era5_varatts(vis_case1.cond_quantiple_ds[var_fcst], vis_case1.cond_quantiple_ds[var_fcst].name)
+    data_ref = get_era5_varatts(vis_case1.cond_quantiple_ds[var_ref], vis_case1.cond_quantiple_ds[var_ref].name)
+    print("data_fcast",data_fcst)
+    fhhs = data_fcst["fcst_hour"]
+    hh = 1
+    quantile_panel_cf, cond_variable_cf = calculate_cond_quantiles(data_fcst.sel(fcst_hour=hh),
+                                                                   data_ref.sel(fcst_hour=hh),
+                                                                   factorization="calibration_refinement",
+                                                                   quantiles=(0.05, 0.5, 0.95))
+
+    # build integer-spaced bin edges by hand from the range of the conditioning (forecast) data
+    data_cond = data_fcst.sel(fcst_hour=hh)
+    data_tar = data_ref.sel(fcst_hour=hh)
+    data_cond_min, data_cond_max = np.floor(np.min(data_cond)), np.ceil(np.max(data_cond))
+    bins = list(np.arange(int(data_cond_min), int(data_cond_max) + 1))
+    nbins = len(bins) - 1
+    bin_l_1, bin_r_1 = bins[0], bins[1]
+    # keep only the target values whose conditioning value falls into the first bin
+    data_cropped = data_tar.where(np.logical_and(data_cond >= bin_l_1, data_cond < bin_r_1))
+   
+
+    
+       
+
+#def test_run_determinstic_quantile_plot(vis_case1):
+#    vis_case1.init_metric_ds()
+
+
+
 #def test_make_test_dataset_iterator(vis_case1):
 #    vis_case1.make_test_dataset_iterator()
 #    pass
@@ -159,7 +201,4 @@ def test_run_deterministic(vis_case1):
 
 
 
-############Test case 2##################
-
-
 
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index cc580b3e59e5aeed8e12aba2f4cccce2c870e92e..e7d982c12acaddb9352299240181152ef880e522 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -779,8 +779,8 @@ class Postprocess(TrainModel):
         Postprocess.clean_obj_attribute(self, "eval_metrics_ds")
 
         # the variables for conditional quantile plot
-        var_fcst = "{0}_{1}_fcst".format(self.vars_in[self.channel], self.model)
-        var_ref = "{0}_ref".format(self.vars_in[self.channel])
+        var_fcst = self.cond_quantile_vars[0]
+        var_ref = self.cond_quantile_vars[1]
 
         data_fcst = get_era5_varatts(self.cond_quantiple_ds[var_fcst], self.cond_quantiple_ds[var_fcst].name)
         data_ref = get_era5_varatts(self.cond_quantiple_ds[var_ref], self.cond_quantiple_ds[var_ref].name)
@@ -1091,7 +1091,7 @@ class Postprocess(TrainModel):
         if not set(varnames).issubset(ds_in.data_vars):
             raise ValueError("%{0}: Could not find all variables ({1}) in input dataset ds_in.".format(method,
                                                                                                        varnames_str))
-
+        # Bing: why is dtype an argument, since it seems you only want to configure dtype as np.double?
         if dtype is None:
             dtype = np.double
         else:
diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py
index 8e003e4533bfb62a943d6b6b9f8d13513558bb88..5d0fb397ef4be4d9c1e8bd3aad6d80bdc1a9925b 100644
--- a/video_prediction_tools/utils/general_utils.py
+++ b/video_prediction_tools/utils/general_utils.py
@@ -12,7 +12,6 @@ Provides:   * get_unique_vars
 
 # import modules
 import os
-import sys
 import numpy as np
 import xarray as xr
 
@@ -138,7 +137,7 @@ def check_dir(path2dir: str, lcreate=False):
     if (path2dir is None) or not isinstance(path2dir, str):
         raise ValueError("%{0}: path2dir must be a string defining a pat to a directory.".format(method))
 
-    if os.path.isdir(path2dir):
+    elif os.path.isdir(path2dir):
         return True
     else:
         if lcreate:
@@ -177,7 +176,7 @@ def provide_default(dict_in, keyname, default=None, required=False):
         return dict_in[keyname]
 
 
-def get_era5_varatts(data_arr: xr.DataArray, name):
+def get_era5_varatts(data_arr: xr.DataArray, name: str):
     """
     Writes longname and unit to data arrays given their name is known
     :param data_arr: the data array