diff --git a/test/run_pytest.sh b/test/run_pytest.sh
index b440bebf7026eb777f0bd3aa82b7894f5c0e540c..ed2b531562282d8fc6a34c07cf8e46cad1a83460 100644
--- a/test/run_pytest.sh
+++ b/test/run_pytest.sh
@@ -21,8 +21,10 @@ fi
 #python -m pytest test_prepare_era5_data.py
 ##Test for preprocess_step1
 #python -m pytest test_process_netCDF_v2.py
-
 source ../video_prediction_tools/env_setup/modules_train.sh
+##Test for preprocess moving mnist
+#python -m pytest test_prepare_moving_mnist_data.py
+python -m pytest test_train_moving_mnist_data.py
 #Test for process step2
 #python -m pytest test_data_preprocess_step2.py
 #python -m pytest test_era5_data.py
@@ -31,5 +33,5 @@ source ../video_prediction_tools/env_setup/modules_train.sh
 #rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/*
 #python -m pytest test_train_model_era5.py
 #python -m pytest test_vanilla_vae_model.py
-python -m pytest test_visualize_postprocess.py
+#python -m pytest test_visualize_postprocess.py
 #python -m pytest test_meta_postprocess.py
diff --git a/video_prediction_tools/data_preprocess/dataset_options.py b/video_prediction_tools/data_preprocess/dataset_options.py
index 28dffb6c8879bd934c6a8f7169ee0a6bcf679999..5e9729d693720e0e1380170a436980fdbeb900e7 100644
--- a/video_prediction_tools/data_preprocess/dataset_options.py
+++ b/video_prediction_tools/data_preprocess/dataset_options.py
@@ -16,4 +16,4 @@ def known_datasets():
         #                    "era5_anomaly":"ERA5Dataset_v2_anomaly",
                             }
 
-    return dataset_mappings
\ No newline at end of file
+    return dataset_mappings
diff --git a/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py b/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..444a6e0bdb0c11b25d19984236a71a3cefb9a2fa
--- /dev/null
+++ b/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py
@@ -0,0 +1,129 @@
+"""
+Class and functions required for preprocessing the Moving MNIST data from the raw .npy archive to TFRecords.
+"""
+__email__ = "b.gong@fz-juelich.de"
+__author__ = "Bing Gong, Karim Mache"
+__date__ = "2021_05_04"
+
+
+import os
+import numpy as np
+import tensorflow as tf
+import argparse
+from model_modules.video_prediction.datasets.moving_mnist import MovingMnist
+
+
+class MovingMnist2Tfrecords(MovingMnist):
+
+    def __init__(self, input_dir=None, dest_dir=None, sequences_per_file=128):
+        """
+        This class is used for converting the Moving MNIST .npy file to TFRecords.
+
+        :param input_dir: str, the directory that contains the Moving MNIST .npy file
+        :param dest_dir: str, the output directory where the TFRecords are saved
+        :param sequences_per_file: int, how many sequences/samples are saved per TFRecord file
+        """
+        self.input_dir = input_dir
+        self.output_dir = dest_dir
+        os.makedirs(self.output_dir, exist_ok=True)
+        self.sequences_per_file = sequences_per_file
+        self.write_sequence_file()
+
+    def __call__(self):
+        """
+        Steps to convert the .npy file to TFRecords
+        :return: None
+        """
+        self.read_npz_file()
+        self.save_npz_to_tfrecords()
+
+    def read_npz_file(self):
+        self.data = np.load(os.path.join(self.input_dir, "mnist_test_seq.npy"))
+        print("shape of data in mnist_test_seq:", self.data.shape)
+        return None
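+
+    # Note (illustrative): the standard Moving MNIST archive "mnist_test_seq.npy" holds an array of
+    # shape [sequence_length, num_samples, height, width], i.e. (20, 10000, 64, 64). A minimal sanity
+    # check along these lines could be run before the conversion (the path is hypothetical):
+    #
+    #   import numpy as np
+    #   data = np.load("/path/to/mnist_test_seq.npy")
+    #   assert data.ndim == 4 and data.shape[0] == 20, "unexpected Moving MNIST layout"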
+
+    def save_npz_to_tfrecords(self):
+        """
+        Read the Moving MNIST data, which comes as a .npy array of shape
+        [seq_length, num_samples, height, width], and save it to TFRecord files.
+        Moving MNIST only has one channel.
+        """
+        idx = 0
+        num_samples = self.data.shape[1]
+        if len(self.data.shape) == 4:
+            # add one dim to represent the channel, which yields [seq_length, num_samples, height, width, channel]
+            self.data = np.expand_dims(self.data, axis=4)
+        elif len(self.data.shape) == 5:
+            pass
+        else:
+            raise ValueError("The input Moving MNIST file has {} dimensions, but either 4 or 5 are expected;"
+                             " please check your data source!".format(len(self.data.shape)))
+
+        self.data = self.data.astype(np.float32)
+        self.data /= 255.0  # normalize the pixel values by dividing by the maximum grayscale value (255)
+        while idx < num_samples - self.sequences_per_file:
+            sequences = self.data[:, idx:idx+self.sequences_per_file, :, :, :]
+            output_fname = 'sequence_index_{}_to_{}.tfrecords'.format(idx, idx + self.sequences_per_file - 1)
+            output_fname = os.path.join(self.output_dir, output_fname)
+            MovingMnist2Tfrecords.save_tf_record(output_fname, sequences)
+            idx = idx + self.sequences_per_file
+        return None
+
+    @staticmethod
+    def save_tf_record(output_fname, sequences):
+        with tf.python_io.TFRecordWriter(output_fname) as writer:
+            # write every sample of this chunk as one tf.train.Example
+            for i in range(sequences.shape[1]):
+                sequence = sequences[:, i, :, :, :]
+                num_frames = len(sequence)
+                height, width = sequence[0, :, :, 0].shape
+                encoded_sequence = np.array([list(image) for image in sequence])
+                features = tf.train.Features(feature={
+                    'sequence_length': _int64_feature(num_frames),
+                    'height': _int64_feature(height),
+                    'width': _int64_feature(width),
+                    'channels': _int64_feature(1),
+                    'images/encoded': _floats_feature(encoded_sequence.flatten()),
+                })
+                example = tf.train.Example(features=features)
+                writer.write(example.SerializeToString())
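+
+    # Illustrative sketch (not part of the pipeline): an example written above can be decoded back
+    # with the matching feature description; this mirrors the parser used in
+    # model_modules/video_prediction/datasets/moving_mnist.py:
+    #
+    #   keys_to_features = {
+    #       'sequence_length': tf.FixedLenFeature([], tf.int64),
+    #       'height': tf.FixedLenFeature([], tf.int64),
+    #       'width': tf.FixedLenFeature([], tf.int64),
+    #       'channels': tf.FixedLenFeature([], tf.int64),
+    #       'images/encoded': tf.VarLenFeature(tf.float32),
+    #   }
+    #   parsed = tf.parse_single_example(serialized_example, keys_to_features)
+    #   frames = tf.sparse_tensor_to_dense(parsed['images/encoded'])  # flat float vector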
+
+    def write_sequence_file(self):
+        """
+        Generate a txt file that records the number of sequences stored in each TFRecord file.
+        This is mainly used for calculating the number of samples per epoch during training.
+        """
+        with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file:
+            seq_file.write("%d\n" % self.sequences_per_file)
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bytes_list_feature(values):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+
+def _floats_feature(value):
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-input_dir", type=str, help="The input directory that contains the Moving MNIST .npy file",
+                        default="/p/largedata/datasets/moving-mnist")
+    parser.add_argument("-output_dir", type=str)
+    parser.add_argument("-sequences_per_file", type=int, default=2)
+    args = parser.parse_args()
+    inst = MovingMnist2Tfrecords(args.input_dir, args.output_dir, args.sequences_per_file)
+    inst()
+
+
+if __name__ == '__main__':
+    main()
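+
+# Example invocation (the paths are illustrative only):
+#   python prepare_moving_mnist_data.py -input_dir /path/to/moving-mnist \
+#       -output_dir /path/to/tfrecords -sequences_per_file 128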
{ "train":{ - "index1":[0,100], - "index2":[150,200] + "index1":[0,100] }, "val": { - "index1":[110,149] + "index1":[100,149] }, "test": { diff --git a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..b59f6cb2ee96162b2eb6014d7ca6bd37f54d4218 --- /dev/null +++ b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json @@ -0,0 +1,12 @@ + +{ + "batch_size": 10, + "lr": 0.001, + "max_epochs":20, + "context_frames":10, + "sequence_length":20, + "loss_fun":"cross_entropy" +} + + + diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py index 7a70351e7808103e9a3e02e65654f151213c45ec..cd0ec2b230169016cc10aee5ee2ff3d7e4fc611b 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py @@ -18,13 +18,12 @@ def get_dataset_class(dataset): if dataset_class is None: raise ValueError('Invalid dataset %s' % dataset) else: - # ERA5Dataset does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which - # does not need to be a subclass of BaseVideoDataset - if not dataset_class == "ERA5Dataset": - dataset_class = globals().get(dataset_class) - if not issubclass(dataset_class,BaseVideoDataset): - raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class)) - else: - dataset_class = globals().get(dataset_class) + # ERA5Dataset movning_mnist does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which does not need to be a subclass of BaseVideoDataset + #if not dataset_class == "ERA5Dataset" or not dataset_class == "MovingMnist": + # dataset_class = globals().get(dataset_class) + # if not issubclass(dataset_class,BaseVideoDataset): + # raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class)) + #else: + dataset_class = globals().get(dataset_class) return dataset_class diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py index b1136172b203218b5bbc3b052e0d454c2c5bd60f..40df33aaaf82d4764d8b0d55c8a1bed55131e963 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py @@ -8,41 +8,117 @@ import re import tensorflow as tf import numpy as np import skimage.io -from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +from collections import OrderedDict +from tensorflow.contrib.training import HParams +from google.protobuf.json_format import MessageToDict + + +class KTHVideoDataset(object): + def __init__(self,input_dir=None,datasplit_config=None,hparams_dict_config=None, mode='train',seed=None): + """ + This class is used for preparing data for training/validation and test models + args: + input_dir : the path of tfrecords files + datasplit_config : the path pointing to the datasplit_config json file + hparams_dict_config : the path to the dict that contains hparameters, + mode : string, "train","val" or "test" + seed : int, the seed for dataset + """ + self.input_dir = input_dir + self.datasplit_config = datasplit_config + self.mode 
+
+    def get_default_hparams(self):
+        return HParams(**self.get_default_hparams_dict())
 
-class KTHVideoDataset(VarLenFeatureVideoDataset):
-    def __init__(self, *args, **kwargs):
-        super(KTHVideoDataset, self).__init__(*args, **kwargs)
-        from google.protobuf.json_format import MessageToDict
-        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
-        dict_message = MessageToDict(tf.train.Example.FromString(example))
-        feature = dict_message['features']['feature']
-        image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels'])
-
-        self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape
-
     def get_default_hparams_dict(self):
-        default_hparams = super(KTHVideoDataset, self).get_default_hparams_dict()
+        """
+        The function that contains the default hparams.
+        Returns:
+            A dict with the following hyperparameters.
+            context_frames  : the number of ground-truth frames to pass in at start
+            sequence_length : the number of frames in the video sequence
+            max_epochs      : the number of epochs to train the model
+            lr              : the learning rate
+            loss_fun        : the loss function
+        """
         hparams = dict(
             context_frames=10,
             sequence_length=20,
-            long_sequence_length=40,
-            force_time_shift=True,
-            shuffle_on_val=True,
-            use_state=False,
+            max_epochs=20,
+            batch_size=40,
+            lr=0.001,
+            loss_fun="rmse",
+            shuffle_on_val=True,
         )
-        return dict(itertools.chain(default_hparams.items(), hparams.items()))
-
-    @property
-    def jpeg_encoding(self):
-        return False
+        return hparams
+
+    def get_datasplit(self):
+        """
+        Get the datasplit json file
+        """
+        with open(self.datasplit_dict_path) as f:
+            self.d = json.load(f)
+        return self.d
+
+    def get_model_hparams_dict(self):
+        """
+        Get model_hparams_dict from the json file (same logic as in MovingMnist)
+        """
+        self.model_hparams_dict_load = {}
+        if self.hparams_dict_config:
+            with open(self.hparams_dict_config) as f:
+                self.model_hparams_dict_load.update(json.loads(f.read()))
+        return self.model_hparams_dict_load
+
+    def parse_hparams(self):
+        """
+        Parse the hparams setting to override the default ones
+        """
+        parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
+        return parsed_hparams
+
+    def get_tfrecords_filenames_base_datasplit(self):
+        """
+        Get the absolute paths of the .tfrecord files based on the data split patterns
+        """
+        self.filenames = []
+        self.data_mode = self.data_dict[self.mode]
+        self.tf_names = []
+        for year, months in self.data_mode.items():
+            for month in months:
+                tf_files = "sequence_Y_{}_M_{}_*_to_*.tfrecord*".format(year, month)
+                self.tf_names.append(tf_files)
+        # look for tfrecords in input_dir and input_dir/mode directories
+        for files in self.tf_names:
+            self.filenames.extend(glob.glob(os.path.join(self.input_dir, files)))
+        if self.filenames:
+            self.filenames = sorted(self.filenames)  # ensures order is the same across systems
+        if not self.filenames:
+            raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir)
+
+    def get_example_info(self):
+        """
+        Get the data information (shapes) from a tfrecord file (same logic as in MovingMnist)
+        """
+        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
+        dict_message = MessageToDict(tf.train.Example.FromString(example))
+        feature = dict_message['features']['feature']
+        self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length', 'height', 'width', 'channels'])
+        self.image_shape = self.video_shape[1:]
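+
+    # Illustrative note: a tfrecord name matching the glob pattern above would look like
+    # "sequence_Y_2017_M_01_0_to_127.tfrecord" (the year/month values here are hypothetical), e.g.:
+    #
+    #   import glob, os
+    #   files = glob.glob(os.path.join("/path/to/tfrecords", "sequence_Y_2017_M_01_*_to_*.tfrecord*"))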
+
+    def num_examples_per_epoch(self):
+        """
+        Calculate the number of samples provided per epoch for train/val/test
+        """
+        # count how many tfrecord files are available for this mode
+        len_fnames = len(self.filenames)
+        seq_len_file = os.path.join(self.input_dir, 'number_sequences.txt')
+        with open(seq_len_file, 'r') as sequence_lengths_file:
+            sequence_lengths = sequence_lengths_file.readlines()
         sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
-        return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
+        # each tfrecord file stores the same number of sequences (first entry of number_sequences.txt);
+        # keep the result on a separate attribute so the method itself is not shadowed
+        self.num_samples_per_epoch = len_fnames * sequence_lengths[0]
+        return self.num_samples_per_epoch
 
 
 def _bytes_feature(value):
@@ -62,17 +138,12 @@ def partition_data(input_dir):
     fnames = glob.glob(os.path.join(input_dir, '*/*'))
     fnames = [fname for fname in fnames if os.path.isdir(fname)]
     print("frames",fnames[0])
-
    persons = [re.match('person(\d+)_\w+_\w+', os.path.split(fname)[1]).group(1) for fname in fnames]
    persons = np.array([int(person) for person in persons])
-
    train_mask = persons <= 16
-
    train_fnames = [fnames[i] for i in np.where(train_mask)[0]]
    test_fnames = [fnames[i] for i in np.where(~train_mask)[0]]
-
    random.shuffle(train_fnames)
-
    pivot = int(0.95 * len(train_fnames))
    train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:]
    return train_fnames, val_fnames, test_fnames
"vid","n", and "files", Each file has 4 channels, each channel has n sequence images with 64*64 png - - vid = os.path.split(video_dir)[1] - (d,) = [d for d in data if d['vid'] == vid] - for frame_fnames_iter, frame_fnames in enumerate(d['files']): - frame_fnames = [os.path.join(video_dir, frame_fname) for frame_fname in frame_fnames] - frames = skimage.io.imread_collection(frame_fnames) - # they are grayscale images, so just keep one of the channels - frames = [frame[..., 0:1] for frame in frames] - - if not sequences: #The length of the sequence in sequences could be different - last_start_sequence_iter = sequence_iter - print("reading sequences starting at sequence %d" % sequence_iter) - - sequences.append(frames) - sequence_iter += 1 - sequence_lengths_file.write("%d\n" % len(frames)) - - if (len(sequences) == sequences_per_file or - (video_iter == (len(video_dirs) - 1) and frame_fnames_iter == (len(d['files']) - 1))): - output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) - output_fname = os.path.join(output_dir, output_fname) - save_tf_record(output_fname, sequences) - sequences[:] = [] - sequence_lengths_file.close() + + def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128): + partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test + sequences = [] + sequence_iter = 0 + sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') + for video_iter, video_dir in enumerate(video_dirs): #Interate group (e.g. walking) each person + meta_partition_name = partition_name if partition_name == 'test' else 'train' + meta_fname = os.path.join(os.path.split(video_dir)[0], '%s_meta%dx%d.pkl' % + (meta_partition_name, image_size, image_size)) + with open(meta_fname, "rb") as f: + data = pickle.load(f) # The data has 62 items, each item is a dict, with three keys. 
"vid","n", and "files", Each file has 4 channels, each channel has n sequence images with 64*64 png + + vid = os.path.split(video_dir)[1] + (d,) = [d for d in data if d['vid'] == vid] + for frame_fnames_iter, frame_fnames in enumerate(d['files']): + frame_fnames = [os.path.join(video_dir, frame_fname) for frame_fname in frame_fnames] + frames = skimage.io.imread_collection(frame_fnames) + # they are grayscale images, so just keep one of the channels + frames = [frame[..., 0:1] for frame in frames] + + if not sequences: #The length of the sequence in sequences could be different + last_start_sequence_iter = sequence_iter + print("reading sequences starting at sequence %d" % sequence_iter) + + sequences.append(frames) + sequence_iter += 1 + sequence_lengths_file.write("%d\n" % len(frames)) + + if (len(sequences) == sequences_per_file or + (video_iter == (len(video_dirs) - 1) and frame_fnames_iter == (len(d['files']) - 1))): + output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) + output_fname = os.path.join(output_dir, output_fname) + save_tf_record(output_fname, sequences) + sequences[:] = [] + sequence_lengths_file.close() def main(): @@ -141,12 +213,10 @@ def main(): parser.add_argument("output_dir", type=str) parser.add_argument("image_size", type=int) args = parser.parse_args() - partition_names = ['train', 'val', 'test'] print("input dir", args.input_dir) partition_fnames = partition_data(args.input_dir) print("partiotion_fnames[0]", partition_fnames[0]) - for partition_name, partition_fnames in zip(partition_names, partition_fnames): partition_dir = os.path.join(args.output_dir, partition_name) if not os.path.exists(partition_dir): diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py index 5ef54d379dc796a786a52c1fc535432f079a4b43..3d2bd6d7462320baaa6bc8f5414506491d6d0e1f 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py @@ -1,118 +1,208 @@ -import argparse -import sys + +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong, Karim" +__date__ = "2021-05-03" + + + import glob -import itertools import os -import pickle import random -import re -import numpy as np import json import tensorflow as tf from tensorflow.contrib.training import HParams -from mpi4py import MPI from collections import OrderedDict -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset -import data_preprocess.process_netCDF_v2 -from general_utils import get_unique_vars -from statistics import Calc_data_stat -from metadata import MetaData - -class MovingMnist(VarLenFeatureVideoDataset): - def __init__(self, *args, **kwargs): - super(MovingMnist, self).__init__(*args, **kwargs) - from google.protobuf.json_format import MessageToDict +from google.protobuf.json_format import MessageToDict + + +class MovingMnist(object): + def __init__(self, input_dir=None, datasplit_config=None, hparams_dict_config=None, mode="train",seed=None): + """ + This class is used for preparing the data for moving mnist, and split the data to train/val/testing + :params input_dir: the path of tfrecords files + :params datasplit_config: the path pointing to the datasplit_config json file + :params hparams_dict_config: the path to the dict that 
+
+    def get_datasplit(self):
+        """
+        Get the datasplit json file
+        """
+        with open(self.datasplit_dict_path) as f:
+            self.d = json.load(f)
+        return self.d
+
+    def get_model_hparams_dict(self):
+        """
+        Get model_hparams_dict from the json file
+        """
+        self.model_hparams_dict_load = {}
+        if self.hparams_dict_config:
+            with open(self.hparams_dict_config) as f:
+                self.model_hparams_dict_load.update(json.loads(f.read()))
+        return self.model_hparams_dict_load
+
+    def parse_hparams(self):
+        """
+        Parse the hparams setting to override the default ones
+        """
+        parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
+        return parsed_hparams
+
+    def get_default_hparams(self):
+        return HParams(**self.get_default_hparams_dict())
+
+    def get_default_hparams_dict(self):
+        """
+        The function that contains the default hparams.
+        Returns:
+            A dict with the following hyperparameters.
+            context_frames  : the number of ground-truth frames to pass in at start
+            sequence_length : the number of frames in the video sequence
+            max_epochs      : the number of epochs to train the model
+            lr              : the learning rate
+            loss_fun        : the loss function
+        """
+        hparams = dict(
+            context_frames=10,
+            sequence_length=20,
+            max_epochs=20,
+            batch_size=40,
+            lr=0.001,
+            loss_fun="rmse",
+            shuffle_on_val=True,
+        )
+        return hparams
+
+    def get_tfrecords_filenames_base_datasplit(self):
+        """
+        Get the absolute paths of the .tfrecords files based on the data split patterns
+        """
+        self.filenames = []
+        self.data_mode = self.data_dict[self.mode]
+        self.all_filenames = glob.glob(os.path.join(self.input_dir, "*.tfrecords"))
+        print("self.all_filenames:", self.all_filenames)
+        # collect the files of every index group of the current data split
+        self.tf_names = []
+        for indice_group, index in self.data_mode.items():
+            fs = [MovingMnist.string_filter(max_value=index[1], min_value=index[0], string=s) for s in self.all_filenames]
+            self.tf_names.extend([self.all_filenames[fs_index] for fs_index in range(len(fs)) if fs[fs_index]])
+        print("tf_names:", self.tf_names)
+        # look for tfrecords in input_dir and input_dir/mode directories
+        for files in self.tf_names:
+            self.filenames.extend(glob.glob(os.path.join(self.input_dir, files)))
+        if self.filenames:
+            self.filenames = sorted(self.filenames)  # ensures order is the same across systems
+        if not self.filenames:
+            raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir)
+
+    @staticmethod
+    def string_filter(max_value=None, min_value=None, string="input_dir/sequence_index_0_to_10.tfrecords"):
+        a = os.path.split(string)[-1].split("_")
+        if not len(a) == 5:
+            raise ValueError("The tfrecords filename does not match the expected pattern, for instance: 'sequence_index_0_to_10.tfrecords'")
+        min_index = int(a[2])
+        max_index = int(a[4].split(".")[0])
+        if min_index >= min_value and max_index <= max_value:
+            return True
+        else:
+            return False
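+
+    # Worked example (the filenames are illustrative): for the train split "index1": [0, 99],
+    # string_filter keeps exactly the files whose index range lies inside [0, 99]:
+    #
+    #   MovingMnist.string_filter(99, 0, "in/sequence_index_0_to_63.tfrecords")    # -> True
+    #   MovingMnist.string_filter(99, 0, "in/sequence_index_64_to_127.tfrecords")  # -> False (127 > 99)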
+
+    def get_example_info(self):
+        """
+        Get the data information (shapes) from a tfrecord file
+        """
         example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
         dict_message = MessageToDict(tf.train.Example.FromString(example))
         feature = dict_message['features']['feature']
         print("features in dataset:", feature.keys())
         self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels'])
         self.image_shape = self.video_shape[1:]
-        self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape
-
-    def get_default_hparams_dict(self):
-        default_hparams = super(MovingMnist, self).get_default_hparams_dict()
-        hparams = dict(
-            context_frames=10,#Bing: Todo oriignal is 10
-            sequence_length=20,#bing: TODO original is 20,
-            shuffle_on_val=True,
-        )
-        return dict(itertools.chain(default_hparams.items(), hparams.items()))
-
-
-    @property
-    def jpeg_encoding(self):
-        return False
-
     def num_examples_per_epoch(self):
-        with open(os.path.join(self.input_dir, 'number_squences.txt'), 'r') as sequence_lengths_file:
-            sequence_lengths = sequence_lengths_file.readlines()
+        """
+        Calculate the number of samples provided per epoch for train/val/test
+        """
+        # count how many tfrecord files are available for this mode
+        len_fnames = len(self.filenames)
+        seq_len_file = os.path.join(self.input_dir, 'number_sequences.txt')
+        with open(seq_len_file, 'r') as sequence_lengths_file:
+            sequence_lengths = sequence_lengths_file.readlines()
         sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
-        return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
-
-    def filter(self, serialized_example):
-        return tf.convert_to_tensor(True)
-
-
-    def make_dataset_v2(self, batch_size):
+        # each tfrecord file stores the same number of sequences (first entry of number_sequences.txt);
+        # keep the result on a separate attribute so the method itself is not shadowed
+        self.num_samples_per_epoch = len_fnames * sequence_lengths[0]
+        return self.num_samples_per_epoch
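+
+    # Illustrative arithmetic (the numbers are hypothetical): with 78 tfrecord files and
+    # number_sequences.txt recording 128 sequences per file, one epoch provides
+    # 78 * 128 = 9984 samples.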
+
+    def make_dataset(self, batch_size):
+        """
+        Prepare the dataset that feeds batches of size batch_size into the model.
+        If the data belong to the training dataset, they are shuffled;
+        if the data belong to the validation dataset, shuffling is controlled by hparams.shuffle_on_val;
+        if the data belong to the test dataset, they are not shuffled.
+        args:
+            batch_size: int, the number of samples fed into the model per iteration
+        """
+        self.num_epochs = self.hparams.max_epochs
         def parser(serialized_example):
             seqs = OrderedDict()
             keys_to_features = {
-                'width': tf.FixedLenFeature([], tf.int64),
-                'height': tf.FixedLenFeature([], tf.int64),
-                'sequence_length': tf.FixedLenFeature([], tf.int64),
-                'channels': tf.FixedLenFeature([], tf.int64),
-                'images/encoded': tf.VarLenFeature(tf.float32)
-            }
-
-            # for i in range(20):
-            #     keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string)
+                'width': tf.FixedLenFeature([], tf.int64),
+                'height': tf.FixedLenFeature([], tf.int64),
+                'sequence_length': tf.FixedLenFeature([], tf.int64),
+                'channels': tf.FixedLenFeature([], tf.int64),
+                'images/encoded': tf.VarLenFeature(tf.float32)
+            }
             parsed_features = tf.parse_single_example(serialized_example, keys_to_features)
-            print ("Parse features", parsed_features)
             seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"])
-            #width = tf.sparse_tensor_to_dense(parsed_features["width"])
-            # height = tf.sparse_tensor_to_dense(parsed_features["height"])
-            # channels = tf.sparse_tensor_to_dense(parsed_features["channels"])
-            # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"])
-            images = []
             print("Image shape {}, {},{},{}".format(self.video_shape[0], self.image_shape[0], self.image_shape[1], self.image_shape[2]))
             images = tf.reshape(seq, [self.video_shape[0], self.image_shape[0], self.image_shape[1], self.image_shape[2]], name="reshape_new")
             seqs["images"] = images
             return seqs
         filenames = self.filenames
-        print ("FILENAMES",filenames)
-        #TODO:
-        #temporal_filenames = self.temporal_filenames
         shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
         if shuffle:
             random.shuffle(filenames)
-        dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024)  # todo: what is buffer_size
-        print("files", self.filenames)
-        print("mode", self.mode)
-        dataset = dataset.filter(self.filter)
+        dataset = tf.data.TFRecordDataset(filenames, buffer_size=8 * 1024 * 1024)
         if shuffle:
-            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs))
+            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=self.num_epochs))
         else:
             dataset = dataset.repeat(self.num_epochs)
-
+        if self.mode == "val":
+            dataset = dataset.repeat(20)
         num_parallel_calls = None if shuffle else 1
         dataset = dataset.apply(tf.contrib.data.map_and_batch(
             parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
-        #dataset = dataset.map(parser)
-        # num_parallel_calls = None if shuffle else 1  # for reproducibility (e.g. sampled subclips from the test set)
-        # dataset = dataset.apply(tf.contrib.data.map_and_batch(
-        #    _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))  # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs
-        dataset = dataset.prefetch(batch_size)  # Bing: Take the data to buffer inorder to save the waiting time for GPU
+        dataset = dataset.prefetch(batch_size)
         return dataset
 
     def make_batch(self, batch_size):
-        dataset = self.make_dataset_v2(batch_size)
+        dataset = self.make_dataset(batch_size)
         iterator = dataset.make_one_shot_iterator()
         return iterator.get_next()
@@ -129,108 +219,8 @@ def _floats_feature(value):
     return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
 def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
-def save_tf_record(output_fname, sequences):
-    with tf.python_io.TFRecordWriter(output_fname) as writer:
-        for i in range(len(sequences)):
-            sequence = sequences[:,i,:,:,:]
-            num_frames = len(sequence)
-            height, width = sequence[0,:,:,0].shape
-            encoded_sequence = np.array([list(image) for image in sequence])
-            features = tf.train.Features(feature={
-                'sequence_length': _int64_feature(num_frames),
-                'height': _int64_feature(height),
-                'width': _int64_feature(width),
-                'channels': _int64_feature(1),
-                'images/encoded': _floats_feature(encoded_sequence.flatten()),
-            })
-            example = tf.train.Example(features=features)
-            writer.write(example.SerializeToString())
-
-def read_frames_and_save_tf_records(output_dir,dat_npz, seq_length=20, sequences_per_file=128, height=64, width=64):#Bing: original 128
-    """
-    Read the moving_mnst data which is npz format, and save it to tfrecords files
-    The shape of dat_npz is [seq_length,number_samples,height,width]
-    moving_mnst only has one channel
-
-    """
-    os.makedirs(output_dir,exist_ok=True)
-    idx = 0
-    num_samples = dat_npz.shape[1]
-    dat_npz = np.expand_dims(dat_npz, axis=4) #add one dim to represent channel, then got [seq_length,num_samples,height,width,channel]
-    print("data_npz_shape",dat_npz.shape)
-    dat_npz = dat_npz.astype(np.float32)
-    dat_npz /= 255.0 #normalize RGB codes by dividing it to the max RGB value
-    while idx < num_samples - sequences_per_file:
-        sequences = dat_npz[:,idx:idx+sequences_per_file,:,:,:]
-        output_fname = 'sequence_{}_{}.tfrecords'.format(idx,idx+sequences_per_file)
-        output_fname = os.path.join(output_dir, output_fname)
-        save_tf_record(output_fname, sequences)
-        idx = idx + sequences_per_file
-    return None
-
-
-def write_sequence_file(output_dir,seq_length,sequences_per_file):
-    partition_names = ["train","val","test"]
-    for partition_name in partition_names:
-        save_output_dir = os.path.join(output_dir,partition_name)
-        tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords"))
-        print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter))
-        sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w')
-        for i in range(tfCounter*sequences_per_file):
-            sequence_lengths_file.write("%d\n" % seq_length)
-        sequence_lengths_file.close()
-
-
-def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"):
-    """
-    Plot the seq images
-    """
-
-    if len(np.array(imgs).shape)!=3:raise("img dims should be three: (seq_len,lat,lon)")
-    img_len = imgs.shape[0]
-    fig = plt.figure(figsize=(18,6))
-    gs = gridspec.GridSpec(1, 10)
-    gs.update(wspace = 0., hspace = 0.)
-    for i in range(img_len):
-        ax1 = plt.subplot(gs[i])
-        plt.imshow(imgs[i] ,cmap = 'jet')
-        plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = [])
-    plt.savefig(os.path.join(output_png_dir, label + "_" + str(idx) + ".jpg"))
-    print("images_saved")
-    plt.clf()
+
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking")
-    parser.add_argument("output_dir", type=str)
-    parser.add_argument("-sequences_per_file",type=int,default=2)
-    args = parser.parse_args()
-    current_path = os.getcwd()
-    data = np.load(os.path.join(args.input_dir,"mnist_test_seq.npy"))
-    print("data in minist_test_Seq shape",data.shape)
-    seq_length = data.shape[0]
-    height = data.shape[2]
-    width = data.shape[3]
-    num_samples = data.shape[1]
-    max_npz = np.max(data)
-    min_npz = np.min(data)
-    print("max_npz,",max_npz)
-    print("min_npz",min_npz)
-    #Todo need to discuss how to split the data, since we have totally 10000 samples, the origin paper convLSTM used 10000 as training, 2000 as validation and 3000 for testing
-    dat_train = data[:,:6000,:,:]
-    dat_val = data[:,6000:7000,:,:]
-    dat_test = data[:,7000:,:]
-    #plot_seq_imgs(dat_test[10:,0,:,:],output_png_dir="/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist/convLSTM",idx=1,label="Ground Truth from npz")
-    #save train
-    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"train"),dat_train, seq_length=20, sequences_per_file=40, height=height, width=width)
-    #save val
-    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"val"),dat_val, seq_length=20, sequences_per_file=40, height=height, width=width)
-    #save test
-    #read_frames_and_save_tf_records(os.path.join(args.output_dir,"test"),dat_test, seq_length=20, sequences_per_file=40, height=height, width=width)
-    #write_sequence_file(output_dir=args.output_dir,seq_length=20,sequences_per_file=40)
-
-if __name__ == '__main__':
-    main()
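+
+# Minimal usage sketch (the paths and values are illustrative only, assuming matching tfrecords exist):
+#
+#   dataset = MovingMnist(input_dir="/path/to/tfrecords",
+#                         datasplit_config="data_split/moving_mnist/datasplit.json",
+#                         hparams_dict_config="hparams/moving_mnist/convLSTM/model_hparams_template.json",
+#                         mode="train")
+#   inputs = dataset.make_batch(batch_size=10)  # tensor dict with key "images"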