diff --git a/Zam347_scripts/DataPreprocess_to_tf.sh b/Zam347_scripts/DataPreprocess_to_tf.sh
index 320307c05e5d4789228e123d50abd7cdb5d39c50..608f95348b25c2169ae2963821639de8947c322b 100755
--- a/Zam347_scripts/DataPreprocess_to_tf.sh
+++ b/Zam347_scripts/DataPreprocess_to_tf.sh
@@ -1,4 +1,4 @@
 #!/bin/bash -x
 
 
-python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500 
+python ../video_prediction/datasets/era5_dataset_v2.py /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/hickle/splits/ /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords/ -vars T2 MSL gph500 -height 128 -width 160 -seq_length 20 
diff --git a/Zam347_scripts/train_era5.sh b/Zam347_scripts/train_era5.sh
index 42e28355e6ebe5a99697dc3c57b8b1ad9a859260..0f1b2c8930befb29399d2d943d558ca8d8e412d4 100755
--- a/Zam347_scripts/train_era5.sh
+++ b/Zam347_scripts/train_era5.sh
@@ -2,5 +2,5 @@
 
 
 
-python ../scripts/train_v2.py --input_dir  /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5  --model savp --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json --output_dir /home/${USER}/models/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/ours_savp
+python ../scripts/train_v2.py --input_dir  /home/${USER}/preprocessedData/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/tfrecords --dataset era5  --model savp --model_hparams_dict ../hparams/kth/ours_savp/model_hparams.json --output_dir /home/${USER}/models/era5-Y2015toY2017M01to12-128x160-74d00N71d00E-T_MSL_gph500/ours_savp 
 #srun  python scripts/train.py --input_dir data/era5 --dataset era5  --model savp --model_hparams_dict hparams/kth/ours_savp/model_hparams.json --output_dir logs/era5/ours_savp
diff --git a/video_prediction/datasets/era5_dataset_v2.py b/video_prediction/datasets/era5_dataset_v2.py
index 9ff970517c18155daea43c4646d8c750a3c5a509..606f32f0bb47a66e1190be5b9585e6ef9b5cf752 100644
--- a/video_prediction/datasets/era5_dataset_v2.py
+++ b/video_prediction/datasets/era5_dataset_v2.py
@@ -27,8 +27,9 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
         example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
         dict_message = MessageToDict(tf.train.Example.FromString(example))
         feature = dict_message['features']['feature']
-        image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels'])
-        self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape
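+        # Read the full video shape (sequence_length, height, width, channels) from the first TFRecord example instead of only the per-frame shape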
+        self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
+        self.image_shape = self.video_shape[1:]
+        self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
 
     def get_default_hparams_dict(self):
         default_hparams = super(ERA5Dataset_v2, self).get_default_hparams_dict()
@@ -64,17 +65,23 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
         def parser(serialized_example):
             seqs = OrderedDict()
             keys_to_features = {
-                # 'width': tf.FixedLenFeature([], tf.int64),
-                # 'height': tf.FixedLenFeature([], tf.int64),
+                'width': tf.FixedLenFeature([], tf.int64),
+                'height': tf.FixedLenFeature([], tf.int64),
                 'sequence_length': tf.FixedLenFeature([], tf.int64),
-                # 'channels': tf.FixedLenFeature([],tf.int64),
+                'channels': tf.FixedLenFeature([], tf.int64),
                 # 'images/encoded':  tf.FixedLenFeature([], tf.string)
                 'images/encoded': tf.VarLenFeature(tf.float32)
             }
+
             # for i in range(20):
             #     keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string)
             parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
+            print("Parsed features:", parsed_features)
             seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
+            # width = tf.sparse_tensor_to_dense(parsed_features["width"])
+            # height = tf.sparse_tensor_to_dense(parsed_features["height"])
+            # channels = tf.sparse_tensor_to_dense(parsed_features["channels"])
+            # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"])
             images = []
             # for i in range(20):
             #    images.append(parsed_features["images/encoded"].values[i])
@@ -85,8 +92,8 @@ class ERA5Dataset_v2(VarLenFeatureVideoDataset):
             # images = tf.decode_raw(parsed_features["images/encoded"],tf.int32)
 
             # images = seq
-            images = tf.reshape(seq, [20, 128, 160, 3], name = "reshape_new")
-            print("IMAGES", images)
+            print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]))
+            images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new")
             seqs["images"] = images
             return seqs
         filenames = self.filenames
@@ -154,7 +161,7 @@ def save_tf_record(output_fname, sequences):
             example = tf.train.Example(features=features)
             writer.write(example.SerializeToString())
 
-def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,N_seq,sequences_per_file=128,**kwargs):#Bing: original 128
+def read_frames_and_save_tf_records(output_dir, input_dir, partition_name, vars_in, seq_length=20, sequences_per_file=128, height=64, width=64, channels=3, **kwargs):  # Bing: original 128
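+    # seq_length, height and width replace the hard-coded N_seq=20 and 128x160 frame size, so the record layout can be configured per dataset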
     # ML 2020/04/08:
     # Include vars_in for more flexible data handling (normalization and reshaping)
     # and optional keyword argument for kind of normalization
@@ -189,15 +196,15 @@ def read_frames_and_save_tf_records(output_dir,input_dir,partition_name,vars_in,
     sequence_iter = 0
     sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w')
     X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl"))
-    X_possible_starts = [i for i in range(len(X_train) - N_seq)]
+    X_possible_starts = [i for i in range(len(X_train) - seq_length)]
     for X_start in X_possible_starts:
         print("Interation", sequence_iter)
-        X_end = X_start + N_seq
+        X_end = X_start + seq_length
         #seq = X_train[X_start:X_end, :, :,:]
         seq = X_train[X_start:X_end,:,:]
         #print("*****len of seq ***.{}".format(len(seq)))
         #seq = list(np.array(seq).reshape((len(seq), 64, 64, 3)))
-        seq = list(np.array(seq).reshape((len(seq), 128, 160,nvars)))
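+        # Reshape each slice of the hickle array into (seq_length, height, width, nvars) rather than assuming 128x160 frames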
+        seq = list(np.array(seq).reshape((seq_length, height, width, nvars)))
         if not sequences:
             last_start_sequence_iter = sequence_iter
             print("reading sequences starting at sequence %d" % sequence_iter)
@@ -230,8 +237,9 @@ def main():
     # ML 2020/04/08 S
     # Add vars for ensuring proper normalization and reshaping of sequences
     parser.add_argument("-vars","--variables",dest="variables", nargs='+', type=str, help="Names of input variables.")
-    # parser.add_argument("image_size_h", type=int)
-    # parser.add_argument("image_size_v", type = int)
+    parser.add_argument("-height",type=int,default=64)
+    parser.add_argument("-width",type = int,default=64)
+    parser.add_argument("-seq_length",type=int,default=20)
     args = parser.parse_args()
     current_path = os.getcwd()
     #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits"
@@ -239,7 +247,7 @@ def main():
     partition_names = ['train','val',  'test'] #64,64,3 val has issue#
   
     for partition_name in partition_names:
-        read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,vars_in=args.variables,partition_name=partition_name, N_seq=20, sequences_per_file=2) #Bing: Todo need check the N_seq
+        read_frames_and_save_tf_records(output_dir=args.output_dir, input_dir=args.input_dir, vars_in=args.variables, partition_name=partition_name, seq_length=args.seq_length, height=args.height, width=args.width, sequences_per_file=2)  # Bing: TODO check the seq_length value
         #ead_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq is 10, but it met loading data issue. let's try 5
 
 if __name__ == '__main__':