diff --git a/video_prediction_savp/HPC_scripts/train_era5.sh b/video_prediction_savp/HPC_scripts/train_era5.sh
index f605866056f6b2d9fa179a00850468fee0c72d87..5173564faae730cda10ac3acc072fe9ed43cb7b3 100755
--- a/video_prediction_savp/HPC_scripts/train_era5.sh
+++ b/video_prediction_savp/HPC_scripts/train_era5.sh
@@ -37,7 +37,7 @@ fi
 source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/
 destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/
 
-# for choosing the model
+# for choosing the model: convLSTM, savp, mcnet, vae
 model=convLSTM
 model_hparams=../hparams/era5/${model}/model_hparams.json
 
diff --git a/video_prediction_savp/hparams/era5/mcnet/model_hparams.json b/video_prediction_savp/hparams/era5/mcnet/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..c2edaad9f9ac158f6e7b8d94bb81db16d55d05e8
--- /dev/null
+++ b/video_prediction_savp/hparams/era5/mcnet/model_hparams.json
@@ -0,0 +1,7 @@
+{
+    "batch_size": 10,
+    "lr": 0.001,
+    "max_epochs": 2,
+    "context_frames": 10,
+    "sequence_length": 20
+}
diff --git a/video_prediction_savp/hparams/era5/savp/model_hparams.json b/video_prediction_savp/hparams/era5/savp/model_hparams.json
new file mode 100644
index 0000000000000000000000000000000000000000..641ffb36f764f5ae720a534d7d9eef0ebad644d8
--- /dev/null
+++ b/video_prediction_savp/hparams/era5/savp/model_hparams.json
@@ -0,0 +1,16 @@
+{
+    "batch_size": 4,
+    "lr": 0.0002,
+    "beta1": 0.5,
+    "beta2": 0.999,
+    "l1_weight": 100.0,
+    "l2_weight": 0.0,
+    "kl_weight": 0.01,
+    "video_sn_vae_gan_weight": 0.1,
+    "video_sn_gan_weight": 0.1,
+    "vae_gan_feature_cdist_weight": 10.0,
+    "gan_feature_cdist_weight": 0.0,
+    "state_weight": 0.0,
+    "nz": 32,
+    "max_epochs": 2
+}
diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py
index 1fb401955c39be4807cf7747e43ed660941cb925..4e30c4ce8e65799b88defa7c331d08dd0469c079 100644
--- a/video_prediction_savp/scripts/train_dummy.py
+++ b/video_prediction_savp/scripts/train_dummy.py
@@ -199,7 +199,7 @@ def main():
     parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") 
     parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
 
-    parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
+    parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="fraction of gpu memory to use")
     parser.add_argument("--seed",default=1234, type=int)
 
     args = parser.parse_args()
@@ -232,6 +232,7 @@ def main():
     inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size)
     
     #build model graph
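+    # Assumption: "T_start" holds auxiliary time metadata from the dataset iterator;
+    # drop it so only the tensors expected by the model graph are passed to build_graph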
+    del inputs["T_start"]
     model.build_graph(inputs)
     
     #save all the model, data params to output dirctory
@@ -255,6 +256,7 @@ def main():
     print ("number of exmaples per epoch:",num_examples_per_epoch)
     steps_per_epoch = int(num_examples_per_epoch/batch_size)
     total_steps = steps_per_epoch * max_epochs
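+    # use TensorFlow's shared global step (created on demand) so the training loop
+    # and the per-model graphs refer to the same step counter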
+    global_step = tf.train.get_or_create_global_step()
     #mock total_steps only for fast debugging
     #total_steps = 10
     print ("Total steps for training:",total_steps)
@@ -263,63 +265,77 @@ def main():
         print("parameter_count =", sess.run(parameter_count))
         sess.run(tf.global_variables_initializer())
         sess.run(tf.local_variables_initializer())
-        #model.restore(sess, args.checkpoint)
+        model.restore(sess, args.checkpoint)
         sess.graph.finalize()
-        start_step = sess.run(model.global_step)
+        #start_step = sess.run(model.global_step)
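+        # read the (possibly checkpoint-restored) global step so training resumes at the correct iteration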
+        start_step = sess.run(global_step)
         print("start_step", start_step)
         # start at one step earlier to log everything without doing any training
         # step is relative to the start_step
         train_losses=[]
         val_losses=[]
         run_start_time = time.time()        
-        for step in range(total_steps):
-            global_step = sess.run(model.global_step)
-            print ("global_step:", global_step)
+        for step in range(start_step,total_steps):
+            #global_step = sess.run(global_step)
+ 
+            print ("step:", step)
             val_handle_eval = sess.run(val_handle)
-            
+
             #Fetch variables in the graph
-            fetches = {"global_step":model.global_step}
-            fetches["train_op"] = model.train_op
-            #fetches["latent_loss"] = model.latent_loss
-            fetches["total_loss"] = model.total_loss
 
-            #fetch the specific loss function only for mcnet
-            if model.__class__.__name__ == "McNetVideoPredictionModel":
-                fetches["L_p"] = model.L_p
-                fetches["L_gdl"] = model.L_gdl
-                fetches["L_GAN"]  =model.L_GAN
-            
-            if model.__class__.__name__ == "SAVP":
-                #todo
-                pass
+            fetches = {"train_op": model.train_op}
+            #fetches["latent_loss"] = model.latent_loss
+            fetches["summary"] = model.summary_op 
             
-            fetches["summary"] = model.summary_op       
-            results = sess.run(fetches)
-            train_losses.append(results["total_loss"])          
-            #Fetch losses for validation data
-            val_fetches = {}
-            #val_fetches["latent_loss"] = model.latent_loss
-            val_fetches["total_loss"] = model.total_loss
+            if model.__class__.__name__ == "McNetVideoPredictionModel" or model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+                fetches["global_step"] = model.global_step
+                fetches["total_loss"] = model.total_loss
+                #fetch the specific loss function only for mcnet
+                if model.__class__.__name__ == "McNetVideoPredictionModel":
+                    fetches["L_p"] = model.L_p
+                    fetches["L_gdl"] = model.L_gdl
+                    fetches["L_GAN"]  =model.L_GAN                    
+                results = sess.run(fetches)
+                train_losses.append(results["total_loss"])
+                #Fetch losses for validation data
+                val_fetches = {}
+                #val_fetches["latent_loss"] = model.latent_loss
+                val_fetches["total_loss"] = model.total_loss
+
+
+            if model.__class__.__name__ == "SAVPVideoPredictionModel":
+                fetches['d_loss'] = model.d_loss
+                fetches['g_loss'] = model.g_loss
+                fetches['d_losses'] = model.d_losses
+                fetches['g_losses'] = model.g_losses
+                results = sess.run(fetches)
+                train_losses.append(results["g_losses"])
+                val_fetches = {}
+                #val_fetches["latent_loss"] = model.latent_loss
+                # For SAVP, the total loss is the generator loss
+                val_fetches["total_loss"] = model.g_losses
+
             val_fetches["summary"] = model.summary_op
             val_results = sess.run(val_fetches,feed_dict={train_handle: val_handle_eval})
             val_losses.append(val_results["total_loss"])
-            
+
             summary_writer.add_summary(results["summary"])
             summary_writer.add_summary(val_results["summary"])
             summary_writer.flush()
-             
+
             # global_step will have the correct step count if we resume from a checkpoint
             # global step is read before it's incemented
-            train_epoch = global_step/steps_per_epoch
-            print("progress  global step %d  epoch %0.1f" % (global_step + 1, train_epoch))
-
+            train_epoch = step/steps_per_epoch
+            print("progress  global step %d  epoch %0.1f" % (step + 1, train_epoch))
             if model.__class__.__name__ == "McNetVideoPredictionModel":
-              print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
+                print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
             elif model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
                 print ("Total_loss:{}".format(results["total_loss"]))
+            elif model.__class__.__name__ == "SAVPVideoPredictionModel":
+                print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"]))
             else:
                 print ("The model name does not exist")
-            
+
             #print("saving model to", args.output_dir)
             saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)#
         train_time = time.time() - run_start_time
diff --git a/video_prediction_savp/video_prediction/models/base_model.py b/video_prediction_savp/video_prediction/models/base_model.py
index 0ebe228fcc9c90addf610bed44bb46f090c7e514..846621d8ca1e235c39618951be86fe184a2d974d 100644
--- a/video_prediction_savp/video_prediction/models/base_model.py
+++ b/video_prediction_savp/video_prediction/models/base_model.py
@@ -366,7 +366,7 @@ class VideoPredictionModel(BaseVideoPredictionModel):
             end_lr=0.0,
             decay_steps=(200000, 300000),
             lr_boundaries=(0,),
-            max_steps=350000,
+            max_epochs=35,
             beta1=0.9,
             beta2=0.999,
             context_frames=-1,
diff --git a/video_prediction_savp/video_prediction/models/mcnet_model.py b/video_prediction_savp/video_prediction/models/mcnet_model.py
index 725ce4f46a301b6aa07f3d50ef811584d5b502db..7a376cb7b2ddb4f46b3ad67a6b2cf7e866823427 100644
--- a/video_prediction_savp/video_prediction/models/mcnet_model.py
+++ b/video_prediction_savp/video_prediction/models/mcnet_model.py
@@ -72,7 +72,7 @@ class McNetVideoPredictionModel(BaseVideoPredictionModel):
         hparams = dict(
             batch_size=16,
             lr=0.001,
-            max_steps=350000,
+            max_epochs=350000,
             context_frames = 10,
             sequence_length = 20,
             nz = 16,
@@ -96,7 +96,8 @@ class McNetVideoPredictionModel(BaseVideoPredictionModel):
         self.is_train = True
        
 
-        self.global_step = tf.Variable(0, name='global_step', trainable=False)
+        #self.global_step = tf.Variable(0, name='global_step', trainable=False)
+        self.global_step = tf.train.get_or_create_global_step()
         original_global_variables = tf.global_variables()
 
         # self.xt = tf.placeholder(tf.float32, self.xt_shape, name='xt')
diff --git a/video_prediction_savp/video_prediction/models/savp_model.py b/video_prediction_savp/video_prediction/models/savp_model.py
index ca8acd3f32a5ea1772c9fbf36003149acfdcb950..c510d050c89908d0e06fe0f1a66e355e61c90530 100644
--- a/video_prediction_savp/video_prediction/models/savp_model.py
+++ b/video_prediction_savp/video_prediction/models/savp_model.py
@@ -688,6 +688,7 @@ class SAVPCell(tf.nn.rnn_cell.RNNCell):
 
 def generator_given_z_fn(inputs, mode, hparams):
     # all the inputs needs to have the same length for unrolling the rnn
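+    # debug output: list the input tensors handed to the SAVP generator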
+    print("inputs.items",inputs.items())
     inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
               for name, input in inputs.items()}
     cell = SAVPCell(inputs, mode, hparams)
diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index 7e3fec28dc28c78b8203e1924f17489af8f5075e..6ca386fcda740b7a3da1a16d0ad84dcd08fe653a 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -60,8 +60,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
 
     def build_graph(self, x):
         self.x = x["images"]
-
-        self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        #self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        self.global_step = tf.train.get_or_create_global_step()
         original_global_variables = tf.global_variables()
         # ARCHITECTURE
         self.convLSTM_network()
diff --git a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py
index eec5598305044226280080d630313487c7d847a4..81c556cea556aa4a7415f33ae4de817023c89d9b 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py
@@ -63,7 +63,7 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
             end_lr=0.0,
             decay_steps=(200000, 300000),
             lr_boundaries=(0,),
-            max_steps=350000,
+            max_epochs=35,
             nz=10,
             context_frames=-1,
             sequence_length=-1,
@@ -71,42 +71,16 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         )
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
 
-    def build_graph(self,x):
-        
-        
-        
-
-
-
-        tf.set_random_seed(12345)
+    def build_graph(self, x):
         self.x = x["images"]
-       
-
-        self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        #self.global_step = tf.Variable(0, name = 'global_step', trainable = False)
+        self.global_step = tf.train.get_or_create_global_step()
         original_global_variables = tf.global_variables()
-        self.increment_global_step = tf.assign_add(self.global_step, 1, name = 'increment_global_step')
-
         self.x_hat, self.z_log_sigma_sq, self.z_mu = self.vae_arc_all()
-       
-     
-    
-   
-  
- 
-
-
-
-
-
         self.recon_loss = tf.reduce_mean(tf.square(self.x[:, 1:, :, :, 0] - self.x_hat[:, :-1, :, :, 0]))
-
-
-
-
-
         latent_loss = -0.5 * tf.reduce_sum(
             1 + self.z_log_sigma_sq - tf.square(self.z_mu) -
-            tf.exp(self.z_log_sigma_sq), axis = 1)
+            tf.exp(self.z_log_sigma_sq), axis=1)
         self.latent_loss = tf.reduce_mean(latent_loss)
         self.total_loss = self.recon_loss + self.latent_loss
         self.train_op = tf.train.AdamOptimizer(
@@ -125,51 +99,33 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         self.loss_summary = tf.summary.scalar("total_loss", self.latent_loss)
         self.summary_op = tf.summary.merge_all()
 
-
-
         self.outputs = {}
         self.outputs["gen_images"] = self.x_hat
         global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
         self.saveable_variables = [self.global_step] + global_variables
 
-        return
+        return None
 
 
     @staticmethod
     def vae_arc3(x,l_name=0,nz=16):
         seq_name = "sq_" + str(l_name) + "_"
-        
-        conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
-
 
+        conv1 = ld.conv_layer(x, 3, 2, 8, seq_name + "encode_1")
         conv2 = ld.conv_layer(conv1, 3, 1, 8, seq_name + "encode_2")
-
-
         conv3 = ld.conv_layer(conv2, 3, 2, 8, seq_name + "encode_3")
-
-
         conv4 = tf.layers.Flatten()(conv3)
-
         conv3_shape = conv3.get_shape().as_list()
-
-        
         z_mu = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m")
         z_log_sigma_sq = ld.fc_layer(conv4, hiddens = nz, idx = seq_name + "enc_fc4_m"'enc_fc4_sigma')
         eps = tf.random_normal(shape = tf.shape(z_log_sigma_sq), mean = 0, stddev = 1, dtype = tf.float32)
-        z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps
-        
-        z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1")
-        
-        
+        z = z_mu + tf.sqrt(tf.exp(z_log_sigma_sq)) * eps        
+        z2 = ld.fc_layer(z, hiddens = conv3_shape[1] * conv3_shape[2] * conv3_shape[3], idx = seq_name + "decode_fc1") 
         z3 = tf.reshape(z2, [-1, conv3_shape[1], conv3_shape[2], conv3_shape[3]])
-
         conv5 = ld.transpose_conv_layer(z3, 3, 2, 8,
                                         seq_name + "decode_5")  
-
         conv6  = ld.transpose_conv_layer(conv5, 3, 1, 8,
                                         seq_name + "decode_6")
-        
-
         x_hat = ld.transpose_conv_layer(conv6, 3, 2, 3, seq_name + "decode_8")
 
         return x_hat, z_mu, z_log_sigma_sq, z
@@ -186,6 +142,5 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel):
         x_hat = tf.stack(X, axis = 1)
         z_log_sigma_sq_all = tf.stack(z_log_sigma_sq_all, axis = 1)
         z_mu_all = tf.stack(z_mu_all, axis = 1)
-       
-      
+
         return x_hat, z_log_sigma_sq_all, z_mu_all