diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py
index 3d2bd6d7462320baaa6bc8f5414506491d6d0e1f..adfdf06539174a78be375fa1f2416dade078a1c5 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py
@@ -29,6 +29,7 @@ class MovingMnist(object):
         self.input_dir = input_dir
         self.mode = mode 
         self.seed = seed
+        self.sequence_length = None                             # set lazily by get_example_info() from the first TFRecord example
         if self.mode not in ('train', 'val', 'test'):
             raise ValueError('Invalid mode %s' % self.mode)
         if not os.path.exists(self.input_dir):
@@ -41,8 +42,6 @@ class MovingMnist(object):
         self.get_tfrecords_filename_base_datasplit()
         self.get_example_info()
 
-
-
     def get_datasplit(self):
         """
         Get the datasplit json file
@@ -51,8 +50,6 @@ class MovingMnist(object):
             self.d = json.load(f)
         return self.d
 
-
-
     def get_model_hparams_dict(self):
         """
         Get model_hparams_dict from json file
@@ -142,21 +139,23 @@ class MovingMnist(object):
         dict_message = MessageToDict(tf.train.Example.FromString(example))
         feature = dict_message['features']['feature']
         print("features in dataset:",feature.keys())
-        self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
-        self.image_shape = self.video_shape[1:]
-
+        video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height',
+                                                                                        'width', 'channels'])
+        self.sequence_length = video_shape[0]
+        self.image_shape = video_shape[1:]
 
     def num_examples_per_epoch(self):
         """
         Calculate how many tfrecords samples in the train/val/test
         """
-        #count how many tfrecords files for train/val/testing
+        # count how many TFRecord files exist for the train/val/test split
         len_fnames = len(self.filenames)
-        seq_len_file = os.path.join(self.input_dir, 'number_sequences.txt')
-        with open(seq_len_file, 'r') as sequence_lengths_file:
-             sequence_lengths = sequence_lengths_file.readlines()
-        sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
-        self.num_examples_per_epoch  = len_fnames * sequence_lengths[0]
+        num_seq_file = os.path.join(self.input_dir, 'number_sequences.txt')
+        with open(num_seq_file, 'r') as dfile:
+             num_seqs = dfile.readlines()
+        num_sequences = [int(num_seq.strip()) for num_seq in num_seqs]
+        self.num_examples_per_epoch  = len_fnames * num_sequences[0]
+
         return self.num_examples_per_epoch
 
 
@@ -181,8 +180,10 @@ class MovingMnist(object):
              }
             parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
             seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
-            print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]))
-            images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new")
+            print("Image shape {}, {},{},{}".format(self.sequence_length,self.image_shape[0],self.image_shape[1],
+                                                    self.image_shape[2]))
+            images = tf.reshape(seq, [self.sequence_length,self.image_shape[0],self.image_shape[1],
+                                      self.image_shape[2]], name = "reshape_new")
             seqs["images"] = images
             return seqs
         filenames = self.filenames