diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
index fb61f5f123fd55bcfc8a44652843ee70ba4dfcb8..0e250b47df28d115c8cdfc77fc708eab5e094ce6 100644
--- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
+++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
@@ -109,7 +109,6 @@ def setup_dirs(input_dir,results_png_dir):
     print ("temporal_dir:",temporal_dir)
 
 
-    
 def update_hparams_dict(model_hparams_dict,dataset):
     hparams_dict = dict(model_hparams_dict)
     hparams_dict.update({
@@ -118,7 +117,7 @@ def update_hparams_dict(model_hparams_dict,dataset):
         'repeat': dataset.hparams.time_shift,
     })
     return hparams_dict
-    
+
 
 def psnr(img1, img2):
     mse = np.mean((img1 - img2) ** 2)
@@ -159,6 +158,7 @@ def write_params_to_results_dir(args,output_dir,dataset,model):
         f.write(json.dumps(model.hparams.values(), sort_keys = True, indent = 4))
     return None
 
+
 def denorm_images(stat_fl, input_images_,channel,var):
     norm_cls  = Norm_data(var)
     norm = 'minmax'
@@ -176,11 +176,12 @@ def denorm_images_all_channels(stat_fl,input_images_,*args):
         print("args c:", args[c])
         input_images_all_channles_denorm.append(denorm_images(stat_fl,input_images_,channel=c,var=args[c]))           
     input_images_denorm = np.stack(input_images_all_channles_denorm, axis=-1)
-    print("input_images_denorm shape",input_images_denorm.shape)
+    #print("input_images_denorm shape",input_images_denorm.shape)
     return input_images_denorm
 
 def get_one_seq_and_time(input_images,t_starts,i):
-    input_images_ = input_images[i, :]
+    assert len(np.array(input_images).shape) == 5, "input_images must be 5-dimensional: (batch, seq_len, lat, lon, channel)"
+    input_images_ = input_images[i,:,:,:,:]
     t_start = t_starts[i]
     return input_images_,t_start
 
@@ -194,19 +195,20 @@ def generate_seq_timestamps(t_start,len_seq=20):
     
     
 def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,ts,context_frames,future_length,model_name,fl_name="test.nc"):
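+    #both sequences are expected as (seq_len, lat, lon, channel); the per-channel slicing below relies on this layout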
+    assert len(np.array(input_images_).shape) == len(np.array(gen_images_).shape), "input and generated sequences must have the same number of dimensions"
     
     y_len = len(lats)
     x_len = len(lons)
     ts_len = len(ts)
     ts_input = ts[:context_frames]
     ts_forecast = ts[context_frames:]
-    print("context_frame:",context_frames)
-    print("future_frame",future_length)
-    print("length of ts input:",len(ts_input))
-  
+    #print("context_frame:",context_frames)
+    #print("future_frame",future_length)
+    #print("length of ts input:",len(ts_input))
+
     print("input_images_ shape in netcdf,",input_images_.shape)
     gen_images_ = np.array(gen_images_)
-    
+
     output_file = os.path.join(output_dir,fl_name)
     with Dataset(output_file, "w", format="NETCDF4") as nc_file:
         nc_file.title = 'ERA5 hourly reanalysis data and the forecasting data by deep learning for 2-m above sea level temperatures'
@@ -281,7 +283,6 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t
         gph500_r.units = 'm'
         gph500_r[:,:,:] = input_images_[context_frames:,:,:,2]
         
-        
 
         ################ forecast group  #####################
 
@@ -290,21 +291,53 @@ def save_to_netcdf_per_sequence(output_dir,input_images_,gen_images_,lons,lats,t
         t2.units = 'K'
         t2[:,:,:] = gen_images_[context_frames:,:,:,0]
         print("NetCDF created")
-        
+
         #mean sea level pressure
         msl = nc_file.createVariable("/forecast/{}/MSL".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
         msl.units = 'Pa'
         msl[:,:,:] = gen_images_[context_frames:,:,:,1]
-        
+
         #Geopotential at 500 
         gph500 = nc_file.createVariable("/forecast/{}/GPH500".format(model_name),"f4",("time_forecast","lat","lon"), zlib = True)
         gph500.units = 'm'
         gph500[:,:,:] = gen_images_[context_frames:,:,:,2]        
-        
-        print("{} created".format(output_file))        
-        
+
+        print("{} created".format(output_file)) 
+
     return None
 
+def plot_seq_imgs(imgs,lats,lons,ts,output_png_dir,label="Ground Truth"):
+    """
+    Plot the seq images 
+    """
+
+    if len(np.array(imgs).shape)!=3:raise("img dims should be four: (seq_len,lat,lon)")
+    if np.array(imgs).shape[0]!= len(ts): raise("The len of timestamps should be equal the image seq_len") 
+    fig = plt.figure(figsize=(18,6))
+    gs = gridspec.GridSpec(1, 10)
+    gs.update(wspace = 0., hspace = 0.)
+    xlabels = [round(i, 2) for i in list(np.linspace(np.min(lons), np.max(lons), 5))]
+    ylabels = [round(i, 2) for i in list(np.linspace(np.max(lats), np.min(lats), 5))]
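+    #axis tick labels span the lon/lat ranges; lats run from max to min, so the top row of the image is plotted first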
+    for i in range(len(ts)):
+        t = ts[i]
+        #if i==0 : ax1=plt.subplot(gs[i])
+        ax1 = plt.subplot(gs[i])
+        plt.imshow(imgs[i], cmap = 'jet', vmin=270, vmax=300)
+        ax1.title.set_text("t = " + t.strftime("%Y%m%d%H"))
+        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
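+        #only the left-most panel keeps its axis ticks and the row label; all other panels are stripped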
+        if i == 0:
+            plt.setp([ax1], xticks = list(np.linspace(0, len(lons), 5)), xticklabels = xlabels, yticks = list(np.linspace(0, len(lats), 5)), yticklabels = ylabels)
+            plt.ylabel(label, fontsize=10)
+    output_fname = label + "_TS_" + ts[0].strftime("%Y%m%d%H") + ".jpg"
+    plt.savefig(os.path.join(output_png_dir, output_fname))
+    plt.clf()
+    print("image {} saved".format(output_fname))
+
+
+def get_persistence(ts):
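+    """
+    TODO (Scarlet): not implemented yet. A minimal sketch (assumption, names
+    hypothetical): a persistence forecast repeats the last observed context
+    frame for every timestamp in ts, e.g.
+    np.repeat(last_context_frame[np.newaxis, ...], len(ts), axis=0).
+    """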
+    pass
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input_dir", type = str, required = True,
@@ -371,7 +404,9 @@ def main():
     num_examples_per_epoch = setup_num_samples_per_epoch(args.num_samples,dataset)
     
     inputs = dataset.make_batch(args.batch_size)
+    print("inputs",inputs)
     input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()}
+    print("input_phs",input_phs)
     
     
     # Build graph
@@ -399,6 +434,7 @@ def main():
     #---Scarlet:20200803    
     #while True:
     #Change True to sample_id<=24 for debugging
+
     #loop for in samples
     while sample_ind < 5:
         gen_images_stochastic = []
@@ -407,28 +443,52 @@ def main():
         try:
             input_results = sess.run(inputs)
             input_images = input_results["images"]
+            #get the initial times
             t_starts = input_results["T_start"]
-            print("T_starts:",t_starts)
         except tf.errors.OutOfRangeError:
             break
+
+        #Get prediction values
         feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()}
+        gen_images = sess.run(model.outputs['gen_images'], feed_dict = feed_dict)  # returns [batch_size, seq_len, lat, lon, channel]
+
         #Loop in batch size
         for i in range(args.batch_size):
+
             #get one seq and the corresponding start time point
             input_images_,t_start = get_one_seq_and_time(input_images,t_starts,i)
             #generate time stamps for sequences
             ts = generate_seq_timestamps(t_start,len_seq=sequence_length)
+
             #Renormalized data for inputs
-            stat_fl = os.path.join(args.input_dir,"statistics.json")
-            input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"])
-            #TODO: Just for creating the netCDF file and we copy the input_image_denorm as generate_images_denorm before we got our trained data
-            gen_images_denorm = input_images_denorm #(seq,lat,lon,var)            
+            stat_fl = os.path.join(args.input_dir,"pickle/statistics.json")
+            input_images_denorm = denorm_images_all_channels(stat_fl,input_images_,["T2","MSL","gph500"])
+            print("input_images_denorm",input_images_denorm[0][0])
+
+            #Denormalize the generated (forecast) data
+            gen_images_ = gen_images[i]
+            gen_images_denorm = denorm_images_all_channels(stat_fl,gen_images_,["T2","MSL","gph500"])
+            print("gen_images_denorm:",gen_images_denorm[0][0])
+
             #Save input to netCDF file
             init_date_str = ts[0].strftime("%Y%m%d%H")
             save_to_netcdf_per_sequence(args.results_dir,input_images_denorm,gen_images_denorm,lons,lats,ts,context_frames,future_length,args.model,fl_name="vfp_{}.nc".format(init_date_str))
+
+            #Plot the ground-truth input sequence
+            plot_seq_imgs(imgs=input_images_denorm[:context_frames-1,:,:,0],lats=lats,lons=lons,ts=ts[:context_frames-1],label="Ground Truth",output_png_dir=args.results_dir)
+
+            #Plot the forecast sequence
+            plot_seq_imgs(imgs=gen_images_denorm[context_frames:,:,:,0],lats=lats,lons=lons,ts=ts[context_frames:],label="Forecast by Model " + args.model,output_png_dir=args.results_dir)
             
+            #TODO (Scarlet): plot the persistence image
+            #implement the get_persistence() function
+
+            #instead of generating images for every input, only the first 5 sample_ind examples are plotted for visualisation
+
         sample_ind += args.batch_size
-            #for input_image in input_images_:
+
+
+        #for input_image in input_images_:
 
 #             for stochastic_sample_ind in range(args.num_stochastic_samples):
 #                 input_images_all.extend(input_images)
diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py
index 9c621bbfc46dcf8cebb77b006604ed2f6f50056f..1fb401955c39be4807cf7747e43ed660941cb925 100644
--- a/video_prediction_savp/scripts/train_dummy.py
+++ b/video_prediction_savp/scripts/train_dummy.py
@@ -14,6 +14,8 @@ from video_prediction import datasets, models
 import matplotlib.pyplot as plt
 from json import JSONEncoder
 import pickle as pkl
+
+
 class NumpyArrayEncoder(JSONEncoder):
     def default(self, obj):
         if isinstance(obj, np.ndarray):
@@ -261,9 +263,10 @@ def main():
         print("parameter_count =", sess.run(parameter_count))
         sess.run(tf.global_variables_initializer())
         sess.run(tf.local_variables_initializer())
-        model.restore(sess, args.checkpoint)
+        #model.restore(sess, args.checkpoint)
         sess.graph.finalize()
         start_step = sess.run(model.global_step)
+        print("start_step", start_step)
         # start at one step earlier to log everything without doing any training
         # step is relative to the start_step
         train_losses=[]
@@ -286,6 +289,10 @@ def main():
                 fetches["L_gdl"] = model.L_gdl
                 fetches["L_GAN"]  =model.L_GAN
             
+            if model.__class__.__name__ == "SAVP":
+                #TODO: fetch the SAVP-specific losses here
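+                #a possible sketch (assumption: the SAVP model exposes its GAN
+                #losses as model.g_loss / model.d_loss like the other
+                #adversarial models in this repo):
+                #    fetches["g_loss"] = model.g_loss
+                #    fetches["d_loss"] = model.d_loss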
+                pass
+
             fetches["summary"] = model.summary_op       
             results = sess.run(fetches)
             train_losses.append(results["total_loss"])          
diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index 7560a225e7651728e2ca8d2107d7f32458106c86..7e3fec28dc28c78b8203e1924f17489af8f5075e 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -65,8 +65,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         original_global_variables = tf.global_variables()
         # ARCHITECTURE
         self.convLSTM_network()
-        print("self.x",self.x)
-        print("self.x_hat_context_frames,",self.x_hat_context_frames)
+        #print("self.x",self.x)
+        #print("self.x_hat_context_frames,",self.x_hat_context_frames)
         #self.context_frames_loss = tf.reduce_mean(
         #    tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
         self.total_loss = tf.reduce_mean(
@@ -81,7 +81,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         self.summary_op = tf.summary.merge_all()
         global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
         self.saveable_variables = [self.global_step] + global_variables
-        return
+        return None
 
 
     @staticmethod