diff --git a/test/test_meta_postprocess.py b/test/test_meta_postprocess.py
index 79883adb09450aebf4d5525f8e15357c7598e58c..1892b0c3c22d723eeba5e08cce238a5177e3860b 100644
--- a/test/test_meta_postprocess.py
+++ b/test/test_meta_postprocess.py
@@ -11,6 +11,11 @@ import pytest
 analysis_config = "/p/home/jusers/gong1/juwels/ambs/video_prediction_tools/analysis_config/analysis_test.json" 
 analysis_dir = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/analysis/bing_test1"
 
+test_nc_fl = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny/vfp_date_2017030118_sample_ind_13.nc"
+test_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2015to2017M01to12-64x64-3930N0000E-T2_MSL_gph500/convLSTM/20201221T181605_gong1_sunny"
+
+
+
 
 #setup instance
 @pytest.fixture(scope="module")
@@ -34,38 +39,41 @@ def test_copy_analysis_config(analysis_inst):
 def test_load_analysis_config(analysis_inst):
     analysis_inst.load_analysis_config()
     metrics_test = analysis_inst.metrics[0]
+    test_dir_read = analysis_inst.results_dirs[0]
     assert metrics_test == "mse"
+    assert test_dir_read == test_dir
 
 
 def test_read_values_by_var_from_nc(analysis_inst):
-    file_nc = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1/vfp_date_2017031312_sample_ind_229.nc"
+    file_nc = test_nc_fl
     real,persistent,forecast,time_forecast  = analysis_inst.read_values_by_var_from_nc(fl_nc = file_nc)
     assert len(real) == len(persistent) == len(forecast) 
     assert len(time_forecast) == len(forecast)
 
 
 def test_calculate_metric_one_dir(analysis_inst):
-    file_nc = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1/vfp_date_2017031312_sample_ind_229.nc"
+    file_nc = test_nc_fl 
     real,persistent,forecast,time_forecast  = analysis_inst.read_values_by_var_from_nc(fl_nc = file_nc)
-    eval_persistent,eval_forecast = analysis_inst.calculate_metric_one_img(real, persistent,forecast,metric="mse")
+    eval_forecast = analysis_inst.calculate_metric_one_img(real,forecast,metric="mse")
 
 def test_load_results_dir_parameters(analysis_inst):
     analysis_inst.load_results_dir_parameters()
     assert len(analysis_inst.compare_by_values) == 2
 
-
 def test_calculate_metric_all_dirs(analysis_inst):
-    analysis_inst.calculate_metrics_all_dirs()
-    assert list(analysis_inst.eval_all.keys())[0] == analysis_inst.results_dirs[0]
-    print(analysis_inst.eval_all["/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1"].keys())
-    assert len(analysis_inst.eval_all["/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1"]["persistent"]["mse"][1]) == 10
+    analysis_inst.calculate_metric_all_dirs()
+
+#def test_calculate_metric_all_dirs(analysis_inst):
+#    analysis_inst.calculate_metrics_all_dirs()
+#    assert list(analysis_inst.eval_all.keys())[0] == analysis_inst.results_dirs[0]
+#    print(analysis_inst.eval_all[test_dir].keys())
+#    assert len(analysis_inst.eval_all[test_dir]["persistent"]["mse"][1]) == 10
 
 
-def test_calculate_mean_vars_forecast(analysis_inst):
-    analysis_inst.calculate_metrics_all_dirs()
-    analysis_inst.calculate_mean_vars_forecast()
-    
-    assert len(analysis_inst.results_dict["/p/home/jusers/gong1/juwels/video_prediction_shared_folder/results/era5_test/convLSTM/20201130T1748_gong1"]["forecast"]) == 2
+#def test_calculate_mean_vars_forecast(analysis_inst):
+#    analysis_inst.calculate_metrics_all_dirs()
+#    analysis_inst.calculate_mean_vars_forecast()
+#    assert len(analysis_inst.results_dict[test_dir]["forecasts"]) == 2
 
 
 def test_plot_results(analysis_inst):
diff --git a/video_prediction_tools/main_scripts/main_meta_postprocess.py b/video_prediction_tools/main_scripts/main_meta_postprocess.py
index 084759765243642edc5f710798e2809688656d13..51f6e19aaa3c0525a6304b38251474d0793a0a8f 100644
--- a/video_prediction_tools/main_scripts/main_meta_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_meta_postprocess.py
@@ -14,7 +14,9 @@ import numpy as np
 import shutil
 import glob
 from netCDF4 import Dataset
-from video_prediction.metrics import *
+from  model_modules.video_prediction.metrics import *
+import xarray as xr
+
 
 class MetaPostprocess(object):
     def __init__(self, analysis_config=None, analysis_dir=None, stochastic_ind=0, forecast_type="deterministic"):
@@ -81,140 +83,72 @@ class MetaPostprocess(object):
            #load var prediction, real and persistent values
            real = fl["/analysis/reference/"].variables[var][:]
            persistent = fl["/analysis/persistent/"].variables[var][:]
-           forecast = fl["/forecast/"+var+"/stochastic"].variables[str(stochastic_ind)][:]
+           forecast = fl["/forecasts/"+var+"/stochastic"].variables[str(stochastic_ind)][:]
            time_forecast = fl.variables["time_forecast"][:]
         return real, persistent, forecast, time_forecast   
 
 
     @staticmethod
-    def calculate_metric_one_img(real, persistent,forecast,metric="mse"):
+    def calculate_metric_one_img(real,forecast,metric="mse"):
         if metric == "mse":
          #compare real and persistent
-            eval_persistent = mse_imgs(real,persistent)
             eval_forecast = mse_imgs(real, forecast)
         elif metric == "psnr":
-            eval_persistent = psnr_imgs(real,forecast)
             eval_forecast = psnr_imgs(real,forecast)   
-        return eval_persistent, eval_forecast
+        return  eval_forecast
     
     @staticmethod
     def reshape_eval_to_one_dim(values):
         return np.array(values).flatten()
 
-    def calculate_metrics_all_dirs(self):
-        """
-        Calculate the all the metrics for persistent and forecast results
-        eval_all is dictionary,
-        eval_all = {
-                     <results_dir>: 
-                                   {
-                                     "persistent":
-                                                 {
-                                                 <metric_name1> : eval_values,
-                                                 <metric_name2> : eval_values
-                                                 } 
-                                   
-                                     "forecast" :
-                                                {
-                                                 <metric_name1> : eval_values,
-                                                 <metric_name2> : eval_values                                                
-                                                }
-                                   
-                                   }
-                    }
-
+    def calculate_metric_all_dirs(self,is_persistent=False,metric="mse"):
         """
-        self.eval_all = {}
-        for results_dir in self.results_dirs:
-            self.eval_all.update({results_dir: {"persistent":None}})
-            self.eval_all.update({results_dir: {"forecast":None}})
-            real_all, persistent_all, forecast_all, self.time_forecast = MetaPostprocess.load_prediction_and_real_from_one_dir(results_dir,var="T2",stochastic_ind=self.stochastic_ind)
-            for metric in self.metrics:
-                self.eval_persistent_all = []
-                self.eval_forecast_all = []
-                #loop for real data
-                for idx in range(len(real_all)):
-                    eval_persistent_per_sample_over_ts = []
-                    eval_forecast_per_sample_over_ts = []
-                    
-                    #loop the forecast time
-                    for time in range(len(self.time_forecast)):
-                        #loop for each sample and each timestamp
-                        self.eval_persistent, self.eval_forecast = MetaPostprocess.calculate_metric_one_img(real_all[idx][time],persistent_all[idx][time],forecast_all[idx][time], metric=metric)
-                        eval_persistent_per_sample_over_ts.append(self.eval_persistent)
-                        eval_forecast_per_sample_over_ts.append(self.eval_forecast)
-                    
-                    self.eval_persistent_all.append(list(eval_persistent_per_sample_over_ts))
-                    self.eval_forecast_all.append(list(eval_forecast_per_sample_over_ts))
-                    #the shape of self.eval_persistent_all is [samples,time_forecast]
-                self.eval_all[results_dir]["persistent"] = {metric: list(self.eval_persistent_all)}           
-                self.eval_all[results_dir]["forecast"] = {metric: list(self.eval_forecast_all)}   
+        Return the evaluation metric values of the forecasting model (or of the persistent baseline if is_persistent is True) over the forecast timestamps
         
-    def save_metrics_all_dir_to_json(self):
-        with open("metrics_results.json","w") as f:
-            json.dump(self.eval_all,f)
-
-         
-    def load_results_dir_parameters(self,compare_by="model"):
-        self.compare_by_values = []
-        for results_dir in self.results_dirs:
-            with open(os.path.join(results_dir, "options_checkpoints.json")) as f:
-                self.options = json.loads(f.read())
-                print("self.options:",self.options)
-                #if self.compare_by == "model":
-                self.compare_by_values.append(self.options[compare_by])
-  
-    
-    def calculate_mean_vars_forecast(self):
-        """
-        Calculate the mean varations of persistent and forecast evalaution metrics
+        return:
+               evals_forecast: xarray.DataArray, the evaluation metric values with the dimensions [results_dirs, samples, time_forecast]
+               
         """
-        is_first_persistent = False
+        eval_forecast_all_dirs = []
         for results_dir in self.results_dirs:
-            evals = self.eval_all[results_dir]
-            eval_persistent = evals["persistent"]
-            eval_forecast = evals["forecast"]
-            self.results_dict = {} 
-            for metric in self.metrics:
-                err_stat = []
+            real_all, persistent_all, forecast_all, self.time_forecast = MetaPostprocess.load_prediction_and_real_from_one_dir(results_dir,var="T2",stochastic_ind=self.stochastic_ind)
+            
+            if is_persistent: forecast_all = persistent_all
+            eval_forecast_all = []
+            #loop over all samples
+            for idx in range(len(real_all)):
+                eval_forecast_per_sample_over_ts = []
+                #loop over the forecast time steps
                 for time in range(len(self.time_forecast)):
-                    forecast_values_all = list(eval_forecast[metric])[:][time]
-                    persistent_values_all = list(eval_persistent[metric])[:][time]
-                    forecast_mean = np.mean(np.array(forecast_values_all),axis=0)
-                    persistent_mean = np.mean(np.array(persistent_values_all),axis=0)
-                    forecast_vars = np.var(np.array(forecast_values_all),axis=0)
-                    persistent_vars = np.var(np.array(persistent_values_all),axis=0)
-                    #[time,mean,vars]
-                    self.results_dict[results_dir] = {"persistent":[persistent_mean, persistent_vars]} 
-                    self.results_dict[results_dir].update({"forecast":[forecast_mean,forecast_vars]})
-               
+                    #evaluate the metric for the current sample and forecast time step
+                    eval_forecast = MetaPostprocess.calculate_metric_one_img(real_all[idx][time],forecast_all[idx][time], metric=metric)
+                    eval_forecast_per_sample_over_ts.append(eval_forecast)
+
+                eval_forecast_all.append(list(eval_forecast_per_sample_over_ts))
+            eval_forecast_all_dirs.append(eval_forecast_all)
+
+        times = list(range(len(self.time_forecast)))
+        samples = list(range(len(real_all)))
+        print("shape of list",np.array(eval_forecast_all_dirs).shape)
+        evals_forecast = xr.DataArray(eval_forecast_all_dirs, coords=[self.results_dirs, samples , times], dims=["results_dirs", "samples","time_forecast"])
+        return evals_forecast
 
+    
     def plot_results(self,one_persistent=True):
         """
         Plot the mean and vars for the user-defined metrics
         """
-        
-        self.load_results_dir_parameters()
-        is_first_persistent=True
-        mean_all_persistent = []
-        vars_all_persistent = []
-        mean_all_model = []
-        vars_all_model = []
-        for results_dir in self.results_dirs:
-            mean_all_model.append(self.results_dict[results_dir]["forecast"][0])
-            vars_all_model.append(self.results_dict[results_dir]["forecast"][1]) 
-        
-        if one_persistent==True:
-            mean_all_model.append(self.results_dict[results_dir]["persistent"][0])
-            vars_all_model.append(self.results_dict[results_dir]["persistent"][1])
-            self.compare_by_values.append("persistent")
-        
-        
+        self.load_results_dir_parameters(compare_by="model")
+        evals_forecast = self.calculate_metric_all_dirs(is_persistent=False,metric="mse")
+        t = evals_forecast["time_forecast"]
+        mean_forecast = evals_forecast.groupby("time_forecast").mean(dim="samples").values
+        var_forecast = evals_forecast.groupby("time_forecast").var(dim="samples").values
+        print("mean_foreast",mean_forecast)
         x = np.array(self.compare_by_values)
-        y = np.array(mean_all_model)
-        e = np.array(vars_all_model)
-
-        plt.errorbar(x,y,e,linestyle="None",marker='^')
+        y = np.array(mean_forecast)
+        e = np.array(var_forecast)
+       
+        plt.errorbar(t,y[0],e[0],linestyle="None",marker='^')
         plt.show()
         plt.savefig(os.path.join(self.analysis_dir,self.metrics[0]+".png"))
         plt.close()
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 382cd6a7c82db3b263c28b01d85bc19ecb248cc1..c78eecf87fa704bf06c700825910914f63110094 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -199,8 +199,8 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
             self.test_handle, self.test_tf_dataset.output_types, self.test_tf_dataset.output_shapes)
         self.inputs = self.iterator.get_next()
         self.input_ts = self.inputs["T_start"]
-        if self.dataset == "era5" and self.model == "savp":
-           del self.inputs["T_start"]
+        #if self.dataset == "era5" and self.model == "savp":
+        #   del self.inputs["T_start"]
 
 
     def check_stochastic_samples_ind_based_on_model(self):
@@ -226,19 +226,23 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
         """
         self.input_results = self.sess.run(self.inputs)
         self.input_images = self.input_results["images"]
-        self.t_starts_results = self.sess.run(self.input_ts)
+        self.t_starts_results = self.input_results["T_start"]
         print("t_starts_results:",self.t_starts_results)
         self.t_starts = self.t_starts_results
         #get one seq and the corresponding start time poin
         #self.t_starts = self.input_results["T_start"]
+        self.input_images_denorm_all = []
         for batch_id in range(self.batch_size):
             self.input_images_ = Postprocess.get_one_seq_from_batch(self.input_images,batch_id)
             #Renormalized data for inputs
             ts = Postprocess.generate_seq_timestamps(self.t_starts[batch_id],len_seq=self.sequence_length)
-            self.input_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl,self.input_images_,self.vars_in)
-            assert len(self.input_images_denorm.shape) == 4
-            Postprocess.plot_seq_imgs(imgs = self.input_images_denorm[self.context_frames+1:,:,:,0],lats=self.lats,lons=self.lons,ts=ts[self.context_frames+1:],label="Ground Truth",output_png_dir=self.results_dir)  
-        return self.input_results, self.input_images,self.t_starts
+            input_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl,self.input_images_,self.vars_in)
+            assert len(input_images_denorm.shape) == 4
+            Postprocess.plot_seq_imgs(imgs = input_images_denorm[self.context_frames+1:,:,:,0],lats=self.lats,lons=self.lons,ts=ts[self.context_frames+1:],label="Ground Truth",output_png_dir=self.results_dir)
+            
+            self.input_images_denorm_all.append(list(input_images_denorm))
+
+        return self.input_results, np.array(self.input_images_denorm_all),self.t_starts
 
 
     def run(self):
@@ -253,7 +257,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
             if self.num_samples_per_epoch < self.sample_ind:
                 break
             else:
-                self.input_results, self.input_images, self.t_starts = self.run_and_plot_inputs_per_batch() #run the inputs and plot each sequence images
+                self.input_results, self.input_images_denorm_all, self.t_starts = self.run_and_plot_inputs_per_batch() #run the inputs and plot each sequence images
 
             feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()}
             gen_images_stochastic = [] #[stochastic_ind,batch_size,seq_len,lat,lon,channels]
@@ -293,7 +297,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
             print("persistent_images_per_batch",len(np.array(persistent_images_per_batch)))
             for batch_id in range(self.batch_size):
                 print("batch_id is here",batch_id)
-                self.save_to_netcdf_for_stochastic_generate_images(self.input_images[batch_id], persistent_images_per_batch[batch_id],
+                self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id], persistent_images_per_batch[batch_id],
                                                             np.array(gen_images_stochastic)[:,batch_id,:,:,:,:], 
                                                             fl_name="vfp_date_{}_sample_ind_{}.nc".format(ts_batch[batch_id],self.sample_ind+batch_id))
             
@@ -412,6 +416,7 @@ class Postprocess(TrainModel,ERA5Pkl2Tfrecords):
             gen_images_stochastic: list/array (float), [stochastic_number,seq,lat,lon,channel]
             fl_name              : str, the netcdf file name to be saved
         """
+        print("inputs fpor netcdf:",input_images_)
         assert (len(np.array(input_images_).shape)==len(np.array(gen_images_stochastic).shape))-1
         persistent_images_ = np.array(persistent_images_)
         assert len(persistent_images_.shape) == 4 #[seq,lat,lon,channel]
@@ -709,5 +714,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main() 
-
+    main()