diff --git a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
index 13b93889875779942e5171e5e1d98eebc84fd9f3..3df6f7e2843eb732df4d0be70f410853a0ac2a78 100644
--- a/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
+++ b/video_prediction_savp/scripts/generate_transfer_learning_finetune.py
@@ -357,6 +357,8 @@ def main():
     sess.graph.as_default()
     sess.run(tf.global_variables_initializer())
     sess.run(tf.local_variables_initializer())
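+    # restore the trained model variables from args.checkpoint before generating predictions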
+    model.restore(sess, args.checkpoint)
     
     #model.restore(sess, args.checkpoint)#Bing: Todo: 20200728 Let's only focus on true and persistend data
     sample_ind, gen_images_all, persistent_images_all, input_images_all = initia_save_data()
diff --git a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
index 70797c34344b82ba980031ecc08eeb999861f396..a3c9fc3666eb21ef02f9e5f64c8c95a29034d619 100644
--- a/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
+++ b/video_prediction_savp/video_prediction/datasets/era5_dataset_v2.py
@@ -62,7 +62,6 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
         sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
         return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
 
-
     def filter(self, serialized_example):
         return tf.convert_to_tensor(True)
 
@@ -326,7 +325,7 @@ def read_frames_and_save_tf_records(stats,output_dir,input_file, temp_input_file
     #sequence_lengths_file.close()
     return 
 
-def write_sequence_file(output_dir,seq_length):
+def write_sequence_file(output_dir,seq_length,sequences_per_file):
     
     partition_names = ["train","val","test"]
     for partition_name in partition_names:
@@ -334,7 +333,8 @@ def write_sequence_file(output_dir,seq_length):
         tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords"))
         print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter))
         sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w')
-        for i in range(tfCounter):
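+        # each .tfrecords file holds sequences_per_file sequences, so write one sequence-length entry per sequence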
+        for i in range(tfCounter*sequences_per_file):
             sequence_lengths_file.write("%d\n" % seq_length)
         sequence_lengths_file.close()
     
@@ -350,6 +349,7 @@ def main():
     parser.add_argument("-height",type=int,default=64)
     parser.add_argument("-width",type = int,default=64)
     parser.add_argument("-seq_length",type=int,default=20)
+    parser.add_argument("-sequences_per_file",type=int,default=2)
     args = parser.parse_args()
     current_path = os.getcwd()
     #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
@@ -366,13 +366,13 @@ def main():
                # "2012":[1,2,3,4,5,6,7,8,9,10,11,12],
                # "2013_complete":[1,2,3,4,5,6,7,8,9,10,11,12],
                # "2015":[1,2,3,4,5,6,7,8,9,10,11,12],
-                "2017":[1,2,3,4,5,6,7,8,9,10]
+                "2017_test":[1,2,3,4,5,6,7,8,9,10]
                  },
             "val":
-                {"2017":[11]
+                {"2017_test":[11]
                  },
             "test":
-                {"2017":[12]
+                {"2017_test":[12]
                  }
             }
     
@@ -424,7 +424,8 @@ def main():
             message_counter = message_counter + 1 
             print("Message in from slaver",message_in) 
             
-        write_sequence_file(args.output_dir,args.seq_length)
+        write_sequence_file(args.output_dir,args.seq_length,args.sequences_per_file)
+        
         #write_sequence_file   
     else:
         message_in = comm.recv()
@@ -449,7 +450,7 @@ def main():
                read_frames_and_save_tf_records(year=year,month=my_rank,stats=stats,output_dir=save_output_dir, \
                                                input_file=input_file,temp_input_file=temp_file,vars_in=args.variables, \
                                                partition_name=partition_name,seq_length=args.seq_length, \
-                                               height=args.height,width=args.width,sequences_per_file=20)   
+                                               height=args.height,width=args.width,sequences_per_file=args.sequences_per_file)   
                                                   
             print("Year {} finished",year)
         message_out = ("Node:",str(my_rank),"finished","","\r\n")
diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index c7f3db7ce4fce732312eba0d9f17362faa2e64b5..7560a225e7651728e2ca8d2107d7f32458106c86 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -41,7 +41,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
             lr: learning rate. if decay steps is non-zero, this is the
                 learning rate for steps <= decay_step.
             max_steps: number of training steps.
-            context_frames: the number of ground-truth frames to pass in at
+            context_frames: the number of ground-truth frames to pass in at
                 start. Must be specified during instantiation.
             sequence_length: the number of frames in the video sequence,
                 including the context frames, so this model predicts