diff --git a/test/run_pytest.sh b/test/run_pytest.sh index b440bebf7026eb777f0bd3aa82b7894f5c0e540c..83220d34a51379e93add931ae6e03e9491b5bce4 100644 --- a/test/run_pytest.sh +++ b/test/run_pytest.sh @@ -2,7 +2,7 @@ # Name of virtual environment #VIRT_ENV_NAME="vp_new_structure" -VIRT_ENV_NAME="env_hdfml" +VIRT_ENV_NAME="juwels_env" if [ -z ${VIRTUAL_ENV} ]; then if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then @@ -21,8 +21,10 @@ fi #python -m pytest test_prepare_era5_data.py ##Test for preprocess_step1 #python -m pytest test_process_netCDF_v2.py - source ../video_prediction_tools/env_setup/modules_train.sh +##Test for preprocess moving mnist +#python -m pytest test_prepare_moving_mnist_data.py +python -m pytest test_train_moving_mnist_data.py #Test for process step2 #python -m pytest test_data_preprocess_step2.py #python -m pytest test_era5_data.py @@ -31,5 +33,5 @@ source ../video_prediction_tools/env_setup/modules_train.sh #rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/* #python -m pytest test_train_model_era5.py #python -m pytest test_vanilla_vae_model.py -python -m pytest test_visualize_postprocess.py +#python -m pytest test_visualize_postprocess.py #python -m pytest test_meta_postprocess.py diff --git a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh index 24d189278868d67853c74192793f9c13c3c7fb94..ed35fd8b68d2d8593c4e9ff411fd0c142b360204 100644 --- a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh +++ b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh @@ -4,11 +4,11 @@ #SBATCH --ntasks=1 ##SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=1 -#SBATCH --output=generate_era5-out.%j -#SBATCH --error=generate_era5-err.%j -#SBATCH --time=00:20:00 +#SBATCH --output=postprocess_era5-out.%j +#SBATCH --error=postprocess_era5-err.%j +#SBATCH --time=01:00:00 #SBATCH --gres=gpu:1 -#SBATCH --partition=develgpus +#SBATCH --partition=gpus #SBATCH --mail-type=ALL #SBATCH --mail-user=b.gong@fz-juelich.de ##jutil env activate -p cjjsc42 @@ -47,4 +47,4 @@ model=convLSTM srun python -u ../main_scripts/main_visualize_postprocess.py --checkpoint ${checkpoint_dir} --mode test \ --results_dir ${results_dir} --batch_size 4 \ --num_stochastic_samples 1 \ - > generate_era5-out.out + > postprocess_era5-out_all.${SLURM_JOB_ID} diff --git a/video_prediction_tools/data_preprocess/dataset_options.py b/video_prediction_tools/data_preprocess/dataset_options.py index 28dffb6c8879bd934c6a8f7169ee0a6bcf679999..5e9729d693720e0e1380170a436980fdbeb900e7 100644 --- a/video_prediction_tools/data_preprocess/dataset_options.py +++ b/video_prediction_tools/data_preprocess/dataset_options.py @@ -16,4 +16,4 @@ def known_datasets(): # "era5_anomaly":"ERA5Dataset_v2_anomaly", } - return dataset_mappings \ No newline at end of file + return dataset_mappings diff --git a/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py b/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py new file mode 100644 index 0000000000000000000000000000000000000000..444a6e0bdb0c11b25d19984236a71a3cefb9a2fa --- /dev/null +++ b/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py @@ -0,0 +1,129 @@ +""" +Class and functions required for preprocessing Moving mnist data from .npz to TFRecords +""" +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong, Karim Mache" +__date__ = "2021_05_04" + + +import os 
+import numpy as np
+import tensorflow as tf
+import argparse
+from model_modules.video_prediction.datasets.moving_mnist import MovingMnist
+
+
+class MovingMnist2Tfrecords(MovingMnist):
+
+    def __init__(self, input_dir=None, dest_dir=None, sequences_per_file=128):
+        """
+        This class is used for converting the .npz file to TFRecords
+
+        :param input_dir: str, the directory containing the Moving MNIST .npy file
+        :param dest_dir: the output directory in which the TFRecords are saved
+        :param sequences_per_file: int, how many sequences/samples are saved per TFRecord file
+        """
+        self.input_dir = input_dir
+        self.output_dir = dest_dir
+        os.makedirs(self.output_dir, exist_ok = True)
+        self.sequences_per_file = sequences_per_file
+        self.write_sequence_file()
+
+
+    def __call__(self):
+        """
+        Steps to process the .npy file to TFRecords
+        :return: None
+        """
+        self.read_npz_file()
+        self.save_npz_to_tfrecords()
+
+    def read_npz_file(self):
+        self.data = np.load(os.path.join(self.input_dir, "mnist_test_seq.npy"))
+        print("data in mnist_test_seq shape", self.data.shape)
+        return None
+
+    def save_npz_to_tfrecords(self):  # Bing: original 128
+        """
+        Read the Moving MNIST data, which is in npz format, and save it to TFRecord files.
+        The shape of the data is [seq_length, number_samples, height, width];
+        Moving MNIST only has one channel.
+        """
+        idx = 0
+        num_samples = self.data.shape[1]
+        if len(self.data.shape) == 4:
+            # add one dim to represent the channel, then get [seq_length, num_samples, height, width, channel]
+            self.data = np.expand_dims(self.data, axis = 4)
+        elif len(self.data.shape) == 5:
+            pass
+        else:
+            raise ValueError(f"The input Moving MNIST npz file has {len(self.data.shape)} dimensions, but either 4 or 5 are expected; please further check your data source!")
+
+        self.data = self.data.astype(np.float32)
+        self.data /= 255.0  # normalize the grayscale values by dividing by the maximum value (255)
+        while idx < num_samples - self.sequences_per_file:
+            sequences = self.data[:, idx:idx+self.sequences_per_file, :, :, :]
+            output_fname = 'sequence_index_{}_to_{}.tfrecords'.format(idx, idx + self.sequences_per_file-1)
+            output_fname = os.path.join(self.output_dir, output_fname)
+            MovingMnist2Tfrecords.save_tf_record(output_fname, sequences)
+            idx = idx + self.sequences_per_file
+        return None
+
+    @staticmethod
+    def save_tf_record(output_fname, sequences):
+        with tf.python_io.TFRecordWriter(output_fname) as writer:
+            for i in range(np.array(sequences).shape[1] - 1):
+                sequence = sequences[:, i, :, :, :]
+                num_frames = len(sequence)
+                height, width = sequence[0, :, :, 0].shape
+                encoded_sequence = np.array([list(image) for image in sequence])
+                features = tf.train.Features(feature = {
+                    'sequence_length': _int64_feature(num_frames),
+                    'height': _int64_feature(height),
+                    'width': _int64_feature(width),
+                    'channels': _int64_feature(1),
+                    'images/encoded': _floats_feature(encoded_sequence.flatten()),
+                })
+                example = tf.train.Example(features = features)
+                writer.write(example.SerializeToString())
+
+    def write_sequence_file(self):
+        """
+        Generate a txt file with the number of sequences stored in each TFRecord file.
+        This is mainly used for calculating the number of samples per epoch during training.
+        """
+
+        with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file:
+            seq_file.write("%d\n" % self.sequences_per_file)
+
+
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bytes_list_feature(values):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+def _floats_feature(value):
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-input_dir", type=str, help="The input directory that contains the Moving MNIST npy file", default="/p/largedata/datasets/moving-mnist")
+    parser.add_argument("-output_dir", type=str)
+    parser.add_argument("-sequences_per_file", type=int, default=2)
+    args = parser.parse_args()
+    inst = MovingMnist2Tfrecords(args.input_dir, args.output_dir, args.sequences_per_file)
+    inst()
+
+
+if __name__ == '__main__':
+    main()
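For reference, a minimal sketch of how the converter added above can be driven from Python and how one of the resulting files can be inspected; it is not part of the patch, the paths are placeholders, and the import path simply follows the repository layout:

from data_preprocess.prepare_moving_mnist_data import MovingMnist2Tfrecords
import tensorflow as tf

# convert the raw .npy file (directories below are assumptions)
converter = MovingMnist2Tfrecords(input_dir="/path/to/moving-mnist",   # directory holding mnist_test_seq.npy
                                  dest_dir="/path/to/tfrecords",
                                  sequences_per_file=128)
converter()                                                            # writes sequence_index_*_to_*.tfrecords

# inspect the first record of the first output file with the same TF1.x API used above
record = next(tf.python_io.tf_record_iterator("/path/to/tfrecords/sequence_index_0_to_127.tfrecords"))
example = tf.train.Example.FromString(record)
feat = example.features.feature
print(feat["sequence_length"].int64_list.value[0],   # e.g. 20 frames per sequence
      feat["height"].int64_list.value[0],            # 64
      feat["width"].int64_list.value[0],             # 64
      feat["channels"].int64_list.value[0])          # 1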
diff --git a/video_prediction_tools/data_split/moving_mnist/datasplit.json b/video_prediction_tools/data_split/moving_mnist/datasplit.json
index 217b285d8e105debbe7841735eb50786762ace19..0c199e18b6685404b1e137a139985f1b511bc4c4 100644
--- a/video_prediction_tools/data_split/moving_mnist/datasplit.json
+++ b/video_prediction_tools/data_split/moving_mnist/datasplit.json
@@ -1,11 +1,10 @@
 {
     "train":{
-        "index1":[0,100],
-        "index2":[150,200]
+        "index1":[0,99]
     },
     "val":
     {
-        "index1":[110,149]
+        "index1":[100,149]
     },
     "test":
     {
diff --git a/video_prediction_tools/data_split/moving_mnist/datasplit_template.json b/video_prediction_tools/data_split/moving_mnist/datasplit_template.json
index 11407a0439e7bd3d1397d6dfce9cce660786a866..890b7e4599d429a0ee91fd2ebf79ecf345168dda 100644
--- a/video_prediction_tools/data_split/moving_mnist/datasplit_template.json
+++ b/video_prediction_tools/data_split/moving_mnist/datasplit_template.json
@@ -7,12 +7,11 @@
 # Be aware that this is a prue data file, i.e. do not make use of any Python-functions such as np.range or similar here!
 {
     "train":{
-        "index1":[0,100],
-        "index2":[150,200]
+        "index1":[0,100]
     },
     "val":
     {
-        "index1":[110,149]
+        "index1":[100,149]
     },
     "test":
     {
diff --git a/video_prediction_tools/env_setup/create_env.sh b/video_prediction_tools/env_setup/create_env.sh
index 1a5ce05448b3fa2b0b44ce002165e0ebe4b1c94f..31f2065e9c7475762f654bcb878fcd5de7421ba4 100755
--- a/video_prediction_tools/env_setup/create_env.sh
+++ b/video_prediction_tools/env_setup/create_env.sh
@@ -63,27 +63,27 @@
 else
   ENV_EXIST=0
 fi
 
-# add personal email-address to Batch-scripts
-if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == *juwels* ]]; then
-  if [[ "${HOST_NAME}" == jwlogin2[1-4]* ]]; then
-    # on Juwels Booster, we are in a container environment -> loading modules is not possible
-    echo "***** Note for Juwels Booster! *****"
-    echo "Already checked the required modules?"
-    echo "To do so, run 'source modules_train.sh' after exiting the singularity."
-    echo "***** Note for Juwels Booster! *****"
-  else
-    # load modules and check for their availability
-    echo "***** Checking modules required during the workflow...
*****" - source ${ENV_SETUP_DIR}/modules_preprocess.sh purge - source ${ENV_SETUP_DIR}/modules_train.sh purge - source ${ENV_SETUP_DIR}/modules_postprocess.sh - fi -else - # unset PYTHONPATH on every other machine that is not a known HPC-system - unset PYTHONPATH -fi - +# Create fresh virtual environment or just activate the existing one if [[ "$ENV_EXIST" == 0 ]]; then + # Check modules first + if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == *juwels* ]]; then + if [[ "${HOST_NAME}" == jwlogin2[1-4]* ]]; then + # on Juwels Booster, we are in a container environment -> loading modules is not possible + echo "***** Note for Juwels Booster! *****" + echo "Already checked the required modules?" + echo "To do so, run 'source modules_train.sh' after exiting the singularity." + echo "***** Note for Juwels Booster! *****" + else + # load modules and check for their availability + echo "***** Checking modules required during the workflow... *****" + source ${ENV_SETUP_DIR}/modules_preprocess.sh purge + source ${ENV_SETUP_DIR}/modules_train.sh purge + source ${ENV_SETUP_DIR}/modules_postprocess.sh + fi + else + # unset PYTHONPATH on every other machine that is not a known HPC-system + unset PYTHONPATH + fi # Activate virtual environment and install additional Python packages. echo "Configuring and activating virtual environment on ${HOST_NAME}" @@ -143,7 +143,8 @@ if [[ "$ENV_EXIST" == 0 ]]; then fi info_str="Virtual environment ${ENV_DIR} has been set up successfully." elif [[ "$ENV_EXIST" == 1 ]]; then - # activating virtual env is suifficient + # loading modules of postprocessing and activating virtual env are suifficient + source ${ENV_SETUP_DIR}/modules_postprocess.sh source ${ENV_DIR}/bin/activate info_str="Virtual environment ${ENV_DIR} has been activated successfully." fi @@ -154,10 +155,10 @@ source "${WORKING_DIR}"/utils/runscript_generator/setup_runscript_templates.sh echo "******************************************** NOTE ********************************************" echo "${info_str}" -echo "Make use of config_runscript.py to generate customized runscripts of the workflow steps." +echo "Make use of generate_runscript.py to generate customized runscripts of the workflow steps." 
echo "******************************************** NOTE ********************************************" # finally clean up loaded modules (if we are not on Juwels) -if [[ "${HOST_NAME}" == *hdfml* || "${HOST_NAME}" == *juwels* ]] && [[ "${HOST_NAME}" != jwlogin2[1-4]* ]]; then - module --force purge -fi +#if [[ "${HOST_NAME}" == *hdfml* || "${HOST_NAME}" == *juwels* ]] && [[ "${HOST_NAME}" != jwlogin2[1-4]* ]]; then +# module --force purge +#fi diff --git a/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json index 17783b5f68c1974ecb12c43948c14fcba77acd8e..878c29a0553ddb74a563299a7d3ec5683469194c 100644 --- a/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json @@ -4,10 +4,8 @@ "lr": 0.001, "max_epochs":20, "context_frames":10, - "sequence_length":20, "loss_fun":"rmse", "shuffle_on_val":false - } diff --git a/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..a2b9be547d450d49ef230e37821f503838a5dcee --- /dev/null +++ b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json @@ -0,0 +1,14 @@ + +{ + "batch_size": 4, + "lr": 0.001, + "max_epochs":20, + "context_frames":12, + "loss_fun":"rmse", + "shuffle_on_val":false, + "recon_weight":0.6 + +} + + + diff --git a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json b/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json index c2edaad9f9ac158f6e7b8d94bb81db16d55d05e8..0b3788d726fc91e3d5c1aec98166259a2e0012e9 100644 --- a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json @@ -3,9 +3,7 @@ "batch_size": 10, "lr": 0.001, "max_epochs":2, - "context_frames":10, - "sequence_length":20 - + "context_frames":10 } diff --git a/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json b/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json index c19ecf6d3b565268efb5f74e14943f32a19519b4..0ccf44e6370f765857204317f172c866865b4b35 100644 --- a/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json @@ -13,6 +13,5 @@ "state_weight": 0.0, "nz": 32, "max_epochs":2, - "context_frames":12, - "sequence_length":24 + "context_frames":12 } diff --git a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json b/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json index 0acefc42a13583d13ac1c263ea44593a6d5e17d0..8e96727e95f761efc170365abd4e8af89696c168 100644 --- a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json @@ -11,7 +11,5 @@ "state_weight": 0.0, "nz": 32, "max_epochs":2, - "context_frames":10, - "sequence_length":20 - + "context_frames":10 } diff --git a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json index d7058c6b2534d46cd6e08672d33a76bfcc4c7a35..d182658a2161a0405d9fb92d9677fc34bd39251f 100644 --- a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json +++ 
b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json @@ -13,8 +13,7 @@ "state_weight": 0.0, "nz": 16, "max_epochs":2, - "context_frames":10, - "sequence_length":20 + "context_frames":10 } diff --git a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json b/video_prediction_tools/hparams/era5/vae/model_hparams_template.json index 1afb4d4391421b34111c252f4775d448be7675d0..2dcecd346b9b4adc4f3179020d0ee83b8512c6a0 100644 --- a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/vae/model_hparams_template.json @@ -5,7 +5,6 @@ "nz":16, "max_epochs":2, "context_frames":10, - "sequence_length":20, "weight_recon":1, "loss_fun": "rmse", "shuffle_on_val": false diff --git a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json index b59f6cb2ee96162b2eb6014d7ca6bd37f54d4218..6cda5552d437b9b283a40bdde77eb1d3b3497b36 100644 --- a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json +++ b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json @@ -4,7 +4,6 @@ "lr": 0.001, "max_epochs":20, "context_frames":10, - "sequence_length":20, "loss_fun":"cross_entropy" } diff --git a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..6cda5552d437b9b283a40bdde77eb1d3b3497b36 --- /dev/null +++ b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json @@ -0,0 +1,11 @@ + +{ + "batch_size": 10, + "lr": 0.001, + "max_epochs":20, + "context_frames":10, + "loss_fun":"cross_entropy" +} + + + diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 053d7d7060529b87a290501e86e1a16d5e9cc4a6..f6fd6e78db9c49f433d679c03c97463a654192d7 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -136,23 +136,22 @@ class TrainModel(object): def setup_dataset(self): """ - Setup train and val dataset instance with the corresponding data split configuration + Setup train and val dataset instance with the corresponding data split configuration. + Simultaneously, sequence_length is attached to the hyperparameter dictionary. 
""" VideoDataset = datasets.get_dataset_class(self.dataset) self.train_dataset = VideoDataset(input_dir=self.input_dir,mode='train',datasplit_config=self.datasplit_dict) self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val',datasplit_config=self.datasplit_dict) - #self.variable_scope = tf.get_variable_scope() - #self.variable_scope.set_use_resource(True) - + + self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) def setup_model(self): """ Set up model instance for the given model names """ VideoPredictionModel = models.get_model_class(self.model) - self.video_model = VideoPredictionModel( - hparams_dict=self.model_hparams_dict_load, - ) + self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict_load) + def setup_graph(self): """ build model graph @@ -178,8 +177,8 @@ class TrainModel(object): self.inputs = self.iterator.get_next() #since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train the model, # otherwise the model will raise error - if self.dataset == "era5" and self.model == "savp": - del self.inputs["T_start"] + #if self.dataset == "era5" and self.model == "savp": + # del self.inputs["T_start"] @@ -232,6 +231,7 @@ class TrainModel(object): self.num_examples = self.train_dataset.num_examples_per_epoch() self.steps_per_epoch = int(self.num_examples/batch_size) self.total_steps = self.steps_per_epoch * max_epochs + print("Batch size is {} ; max_epochs is {}; num_samples per epoch is {}; steps_per_epoch is {}, total steps is {}".format(batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps)) def restore(self,sess, checkpoints, restore_to_checkpoint_mapping=None): """ @@ -294,11 +294,15 @@ class TrainModel(object): self.create_fetches_for_train() # In addition to the loss, we fetch the optimizer self.results = sess.run(self.fetches) # ...and run it here! 
train_losses.append(self.results["total_loss"]) + print("t_start for training",self.results["inputs"]["T_start"]) + print("len of t_start per iteration",len(self.results["inputs"]["T_start"])) #Run and fetch losses for validation data val_handle_eval = sess.run(self.val_handle) self.create_fetches_for_val() self.val_results = sess.run(self.val_fetches,feed_dict={self.train_handle: val_handle_eval}) val_losses.append(self.val_results["total_loss"]) + print("t_start for validation",self.val_results["inputs"]["T_start"]) + print("len of t_start per iteration",len(self.val_results["inputs"]["T_start"])) self.write_to_summary() self.print_results(step,self.results) timeit_end = time.time() @@ -335,6 +339,8 @@ class TrainModel(object): if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": self.fetches_for_train_convLSTM() if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.fetches_for_train_savp() if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": self.fetches_for_train_vae() + if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":self.fetches_for_train_gan() + if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":self.fetches_for_train_convLSTM() return self.fetches def fetches_for_train_convLSTM(self): @@ -342,8 +348,7 @@ class TrainModel(object): Fetch variables in the graph for convLSTM model, this can be custermized based on models and the needs of users """ self.fetches["total_loss"] = self.video_model.total_loss - - + self.fetches["inputs"] = self.video_model.inputs def fetches_for_train_savp(self): @@ -355,7 +360,7 @@ class TrainModel(object): self.fetches["d_loss"] = self.video_model.d_loss self.fetches["g_loss"] = self.video_model.g_loss self.fetches["total_loss"] = self.video_model.g_loss - + self.fetches["inputs"] = self.video_model.inputs def fetches_for_train_mcnet(self): @@ -374,15 +379,19 @@ class TrainModel(object): self.fetches["recon_loss"] = self.video_model.recon_loss self.fetches["total_loss"] = self.video_model.total_loss + def fetches_for_train_gan(self): + self.fetches["total_loss"] = self.video_model.total_loss + def create_fetches_for_val(self): """ Fetch variables in the graph for validation dataset, this can be custermized based on models and the needs of users """ if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.val_fetches = {"total_loss": self.video_model.g_loss} + self.val_fetches["inputs"] = self.video_model.inputs else: self.val_fetches = {"total_loss": self.video_model.total_loss} - + self.val_fetches["inputs"] = self.video_model.inputs self.val_fetches["summary"] = self.video_model.summary_op def write_to_summary(self): diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 196cfee6508ec298dbb390941d787cb6957022c9..60dfef0032a7746d57734285a7b86328e0c74f50 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -422,19 +422,20 @@ class Postprocess(TrainModel): # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel] feed_dict = {input_ph: input_results[name] for name, input_ph in self.inputs.items()} gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict) + # sanity check on length of forecast sequence assert gen_images.shape[1] == 
self.sequence_length - 1, \ "%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(method) # denormalize forecast sequence (self.norm_cls is already set in get_input_data_per_batch-method) gen_images_denorm = self.denorm_images_all_channels(gen_images, self.vars_in, self.norm_cls, norm_method="minmax") - # store data into datset + # store data into datset and get number of samples (may differ from batch_size at the end of the test dataset) times_0, init_times = self.get_init_time(t_starts) batch_ds = self.create_dataset(input_images_denorm, gen_images_denorm, init_times) - # auxilary list of forecast dimensions - dims_fcst = list(batch_ds["{0}_ref".format(self.vars_in[0])].dims) + nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind) + batch_ds = batch_ds.isel(init_time=slice(0, nbs)) - for i in np.arange(self.batch_size): + for i in np.arange(nbs): # work-around to make use of get_persistence_forecast_per_sample-method times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime() # get persistence forecast for sequences at hand and write to dataset @@ -541,9 +542,10 @@ class Postprocess(TrainModel): .format(method, ", ".join(misses))) varname_ref = "{0}_ref".format(varname) - # reset init-time coordinate of metric_ds in place + # reset init-time coordinate of metric_ds in place and get indices for slicing + ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch) init_times_metric = metric_ds["init_time"].values - init_times_metric[ind_start:ind_start+self.batch_size] = data_ds["init_time"] + init_times_metric[ind_start:ind_end] = data_ds["init_time"] metric_ds = metric_ds.assign_coords(init_time=init_times_metric) # populate metric_ds for fcst_prod in self.fcst_products.keys(): diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index ca602a954a107c9217942e7f01e4eae4c68d58bb..5836ab9fce48692252a4dbc44415b4a4f9e2c2c3 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -14,6 +14,8 @@ def known_models(): 'vae': 'VanillaVAEVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel', 'mcnet': 'McNetVideoPredictionModel', + 'gan': "VanillaGANVideoPredictionModel", + 'convLSTM_gan': "ConvLstmGANVideoPredictionModel", 'ours_vae_l1': 'SAVPVideoPredictionModel', 'ours_gan': 'SAVPVideoPredictionModel', } diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py index 7a70351e7808103e9a3e02e65654f151213c45ec..cd0ec2b230169016cc10aee5ee2ff3d7e4fc611b 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py @@ -18,13 +18,12 @@ def get_dataset_class(dataset): if dataset_class is None: raise ValueError('Invalid dataset %s' % dataset) else: - # ERA5Dataset does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which - # does not need to be a subclass of BaseVideoDataset - if not dataset_class == "ERA5Dataset": - dataset_class = globals().get(dataset_class) - if not issubclass(dataset_class,BaseVideoDataset): - raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class)) - else: - dataset_class = globals().get(dataset_class) + # 
ERA5Dataset movning_mnist does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which does not need to be a subclass of BaseVideoDataset + #if not dataset_class == "ERA5Dataset" or not dataset_class == "MovingMnist": + # dataset_class = globals().get(dataset_class) + # if not issubclass(dataset_class,BaseVideoDataset): + # raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class)) + #else: + dataset_class = globals().get(dataset_class) return dataset_class diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index 9bd0362541b858b7ceb8650057d72f17e0188e8c..ce62965a2c92432ffbf739e933279f91b69e355c 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -29,6 +29,7 @@ class ERA5Dataset(object): self.datasplit_config = datasplit_config self.mode = mode self.seed = seed + self.sequence_length = None # will be set in get_example_info if self.mode not in ('train', 'val', 'test'): raise ValueError('Invalid mode %s' % self.mode) if not os.path.exists(self.input_dir): @@ -61,14 +62,12 @@ class ERA5Dataset(object): Returns: A dict with the following hyperparameters. context_frames : the number of ground-truth frames to pass in at start. - sequence_length : the number of frames in the video sequence max_epochs : the number of epochs to train model lr : learning rate loss_fun : the loss function """ hparams = dict( context_frames=10, - sequence_length=20, max_epochs = 20, batch_size = 40, lr = 0.001, @@ -116,26 +115,28 @@ class ERA5Dataset(object): def get_example_info(self): """ - Get the data information from tfrecord file + Get the data information from an example tfrecord file """ example = next(tf.python_io.tf_record_iterator(self.filenames[0])) dict_message = MessageToDict(tf.train.Example.FromString(example)) feature = dict_message['features']['feature'] print("features in dataset:",feature.keys()) - self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels']) - self.image_shape = self.video_shape[1:] + video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length', 'height', + 'width', 'channels']) + self.sequence_length = video_shape[0] + self.image_shape = video_shape[1:] def num_examples_per_epoch(self): """ Calculate how many tfrecords samples in the train/val/test """ - #count how many tfrecords files for train/val/testing + # count how many tfrecords files for train/val/testing len_fnames = len(self.filenames) - seq_len_file = os.path.join(self.input_dir, 'number_sequences.txt') - with open(seq_len_file, 'r') as sequence_lengths_file: - sequence_lengths = sequence_lengths_file.readlines() - sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] - self.num_examples_per_epoch = len_fnames * sequence_lengths[0] + num_seq_file = os.path.join(self.input_dir, 'number_sequences.txt') + with open(num_seq_file, 'r') as dfile: + num_seqs = dfile.readlines() + num_sequences = [int(num_seq.strip()) for num_seq in num_seqs] + self.num_examples_per_epoch = len_fnames * num_sequences[0] return self.num_examples_per_epoch @@ -163,9 +164,10 @@ class ERA5Dataset(object): parsed_features = tf.parse_single_example(serialized_example, keys_to_features) seq = 
tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) T_start = tf.sparse_tensor_to_dense(parsed_features["t_start"]) - images = [] - print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2])) - images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new") + print("Image shape {}, {},{},{}".format(self.sequence_length, self.image_shape[0], self.image_shape[1], + self.image_shape[2])) + images = tf.reshape(seq, [self.sequence_length,self.image_shape[0],self.image_shape[1], + self.image_shape[2]], name="reshape_new") seqs["images"] = images seqs["T_start"] = T_start return seqs @@ -179,7 +181,9 @@ class ERA5Dataset(object): dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs)) else: dataset = dataset.repeat(self.num_epochs) - if self.mode == "val": dataset = dataset.repeat(20) + + if self.mode == "val": dataset = dataset.repeat(20) + num_parallel_calls = None if shuffle else 1 dataset = dataset.apply(tf.contrib.data.map_and_batch( parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py index b1136172b203218b5bbc3b052e0d454c2c5bd60f..40df33aaaf82d4764d8b0d55c8a1bed55131e963 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py @@ -8,41 +8,117 @@ import re import tensorflow as tf import numpy as np import skimage.io -from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +from collections import OrderedDict +from tensorflow.contrib.training import HParams +from google.protobuf.json_format import MessageToDict + + +class KTHVideoDataset(object): + def __init__(self,input_dir=None,datasplit_config=None,hparams_dict_config=None, mode='train',seed=None): + """ + This class is used for preparing data for training/validation and test models + args: + input_dir : the path of tfrecords files + datasplit_config : the path pointing to the datasplit_config json file + hparams_dict_config : the path to the dict that contains hparameters, + mode : string, "train","val" or "test" + seed : int, the seed for dataset + """ + self.input_dir = input_dir + self.datasplit_config = datasplit_config + self.mode = mode + self.seed = seed + if self.mode not in ('train', 'val', 'test'): + raise ValueError('Invalid mode %s' % self.mode) + if not os.path.exists(self.input_dir): + raise FileNotFoundError("input_dir %s does not exist" % self.input_dir) + self.datasplit_dict_path = datasplit_config + self.data_dict = self.get_datasplit() + self.hparams_dict_config = hparams_dict_config + self.hparams_dict = self.get_model_hparams_dict() + self.hparams = self.parse_hparams() + self.get_tfrecords_filesnames_base_datasplit() + self.get_example_info() + + + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) -class KTHVideoDataset(VarLenFeatureVideoDataset): - def __init__(self, *args, **kwargs): - super(KTHVideoDataset, self).__init__(*args, **kwargs) - from google.protobuf.json_format import MessageToDict - example = next(tf.python_io.tf_record_iterator(self.filenames[0])) - dict_message = MessageToDict(tf.train.Example.FromString(example)) - feature = 
dict_message['features']['feature'] - image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels']) - - self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape - def get_default_hparams_dict(self): - default_hparams = super(KTHVideoDataset, self).get_default_hparams_dict() + """ + The function that contains default hparams + Returns: + A dict with the following hyperparameters. + context_frames : the number of ground-truth frames to pass in at start. + sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train model + lr : learning rate + loss_fun : the loss function + """ hparams = dict( context_frames=10, sequence_length=20, - long_sequence_length=40, - force_time_shift=True, - shuffle_on_val=True, - use_state=False, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "rmse", + shuffle_on_val= True, ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - @property - def jpeg_encoding(self): - return False + return hparams + + + + + def get_datasplit(self): + """ + Get the datasplit json file + """ + + with open(self.datasplit_dict_path) as f: + self.d = json.load(f) + return self.d + + def parse_hparams(self): + """ + Parse the hparams setting to ovoerride the default ones + """ + parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams + + + def get_tfrecords_filesnames_base_datasplit(self): + """ + Get absolute .tfrecord path names based on the data splits patterns + """ + self.filenames = [] + self.data_mode = self.data_dict[self.mode] + self.tf_names = [] + for year, months in self.data_mode.items(): + for month in months: + tf_files = "sequence_Y_{}_M_{}_*_to_*.tfrecord*".format(year,month) + self.tf_names.append(tf_files) + # look for tfrecords in input_dir and input_dir/mode directories + for files in self.tf_names: + self.filenames.extend(glob.glob(os.path.join(self.input_dir, files))) + if self.filenames: + self.filenames = sorted(self.filenames) # ensures order is the same across systems + if not self.filenames: + raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir) def num_examples_per_epoch(self): - with open(os.path.join(self.input_dir, 'number_sequences.txt'), 'r') as sequence_lengths_file: - sequence_lengths = sequence_lengths_file.readlines() + """ + Calculate how many tfrecords samples in the train/val/test + """ + #count how many tfrecords files for train/val/testing + len_fnames = len(self.filenames) + seq_len_file = os.path.join(self.input_dir, 'number_sequences.txt') + with open(seq_len_file, 'r') as sequence_lengths_file: + sequence_lengths = sequence_lengths_file.readlines() sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] - return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) + self.num_examples_per_epoch = len_fnames * sequence_lengths[0] + return self.num_examples_per_epoch def _bytes_feature(value): @@ -62,17 +138,12 @@ def partition_data(input_dir): fnames = glob.glob(os.path.join(input_dir, '*/*')) fnames = [fname for fname in fnames if os.path.isdir(fname)] print("frames",fnames[0]) - persons = [re.match('person(\d+)_\w+_\w+', os.path.split(fname)[1]).group(1) for fname in fnames] persons = np.array([int(person) for person in persons]) - train_mask = persons <= 16 - train_fnames = [fnames[i] for i in np.where(train_mask)[0]] test_fnames = [fnames[i] for i in 
np.where(~train_mask)[0]] - random.shuffle(train_fnames) - pivot = int(0.95 * len(train_fnames)) train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:] return train_fnames, val_fnames, test_fnames @@ -96,41 +167,42 @@ def save_tf_record(output_fname, sequences): writer.write(example.SerializeToString()) -def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128): - partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test - sequences = [] - sequence_iter = 0 - sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') - for video_iter, video_dir in enumerate(video_dirs): #Interate group (e.g. walking) each person - meta_partition_name = partition_name if partition_name == 'test' else 'train' - meta_fname = os.path.join(os.path.split(video_dir)[0], '%s_meta%dx%d.pkl' % - (meta_partition_name, image_size, image_size)) - with open(meta_fname, "rb") as f: - data = pickle.load(f) # The data has 62 items, each item is a dict, with three keys. "vid","n", and "files", Each file has 4 channels, each channel has n sequence images with 64*64 png - - vid = os.path.split(video_dir)[1] - (d,) = [d for d in data if d['vid'] == vid] - for frame_fnames_iter, frame_fnames in enumerate(d['files']): - frame_fnames = [os.path.join(video_dir, frame_fname) for frame_fname in frame_fnames] - frames = skimage.io.imread_collection(frame_fnames) - # they are grayscale images, so just keep one of the channels - frames = [frame[..., 0:1] for frame in frames] - - if not sequences: #The length of the sequence in sequences could be different - last_start_sequence_iter = sequence_iter - print("reading sequences starting at sequence %d" % sequence_iter) - - sequences.append(frames) - sequence_iter += 1 - sequence_lengths_file.write("%d\n" % len(frames)) - - if (len(sequences) == sequences_per_file or - (video_iter == (len(video_dirs) - 1) and frame_fnames_iter == (len(d['files']) - 1))): - output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) - output_fname = os.path.join(output_dir, output_fname) - save_tf_record(output_fname, sequences) - sequences[:] = [] - sequence_lengths_file.close() + + def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128): + partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test + sequences = [] + sequence_iter = 0 + sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') + for video_iter, video_dir in enumerate(video_dirs): #Interate group (e.g. walking) each person + meta_partition_name = partition_name if partition_name == 'test' else 'train' + meta_fname = os.path.join(os.path.split(video_dir)[0], '%s_meta%dx%d.pkl' % + (meta_partition_name, image_size, image_size)) + with open(meta_fname, "rb") as f: + data = pickle.load(f) # The data has 62 items, each item is a dict, with three keys. 
"vid","n", and "files", Each file has 4 channels, each channel has n sequence images with 64*64 png + + vid = os.path.split(video_dir)[1] + (d,) = [d for d in data if d['vid'] == vid] + for frame_fnames_iter, frame_fnames in enumerate(d['files']): + frame_fnames = [os.path.join(video_dir, frame_fname) for frame_fname in frame_fnames] + frames = skimage.io.imread_collection(frame_fnames) + # they are grayscale images, so just keep one of the channels + frames = [frame[..., 0:1] for frame in frames] + + if not sequences: #The length of the sequence in sequences could be different + last_start_sequence_iter = sequence_iter + print("reading sequences starting at sequence %d" % sequence_iter) + + sequences.append(frames) + sequence_iter += 1 + sequence_lengths_file.write("%d\n" % len(frames)) + + if (len(sequences) == sequences_per_file or + (video_iter == (len(video_dirs) - 1) and frame_fnames_iter == (len(d['files']) - 1))): + output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) + output_fname = os.path.join(output_dir, output_fname) + save_tf_record(output_fname, sequences) + sequences[:] = [] + sequence_lengths_file.close() def main(): @@ -141,12 +213,10 @@ def main(): parser.add_argument("output_dir", type=str) parser.add_argument("image_size", type=int) args = parser.parse_args() - partition_names = ['train', 'val', 'test'] print("input dir", args.input_dir) partition_fnames = partition_data(args.input_dir) print("partiotion_fnames[0]", partition_fnames[0]) - for partition_name, partition_fnames in zip(partition_names, partition_fnames): partition_dir = os.path.join(args.output_dir, partition_name) if not os.path.exists(partition_dir): diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py index 5ef54d379dc796a786a52c1fc535432f079a4b43..adfdf06539174a78be375fa1f2416dade078a1c5 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py @@ -1,118 +1,209 @@ -import argparse -import sys + +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong, Karim" +__date__ = "2021-05-03" + + + import glob -import itertools import os -import pickle import random -import re -import numpy as np import json import tensorflow as tf from tensorflow.contrib.training import HParams -from mpi4py import MPI from collections import OrderedDict -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset -import data_preprocess.process_netCDF_v2 -from general_utils import get_unique_vars -from statistics import Calc_data_stat -from metadata import MetaData - -class MovingMnist(VarLenFeatureVideoDataset): - def __init__(self, *args, **kwargs): - super(MovingMnist, self).__init__(*args, **kwargs) - from google.protobuf.json_format import MessageToDict - example = next(tf.python_io.tf_record_iterator(self.filenames[0])) - dict_message = MessageToDict(tf.train.Example.FromString(example)) - feature = dict_message['features']['feature'] - print("features in dataset:",feature.keys()) - self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels']) - self.image_shape = self.video_shape[1:] - self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape 
+from google.protobuf.json_format import MessageToDict + + +class MovingMnist(object): + def __init__(self, input_dir=None, datasplit_config=None, hparams_dict_config=None, mode="train",seed=None): + """ + This class is used for preparing the data for moving mnist, and split the data to train/val/testing + :params input_dir: the path of tfrecords files + :params datasplit_config: the path pointing to the datasplit_config json file + :params hparams_dict_config: the path to the dict that contains hparameters + :params mode: string, "train","val" or "test" + :params seed:int, the seed for dataset + :return None + """ + self.input_dir = input_dir + self.mode = mode + self.seed = seed + self.sequence_length = None # will be set in get_example_info + if self.mode not in ('train', 'val', 'test'): + raise ValueError('Invalid mode %s' % self.mode) + if not os.path.exists(self.input_dir): + raise FileNotFoundError("input_dir %s does not exist" % self.input_dir) + self.datasplit_dict_path = datasplit_config + self.data_dict = self.get_datasplit() + self.hparams_dict_config = hparams_dict_config + self.hparams_dict = self.get_model_hparams_dict() + self.hparams = self.parse_hparams() + self.get_tfrecords_filename_base_datasplit() + self.get_example_info() + + def get_datasplit(self): + """ + Get the datasplit json file + """ + with open(self.datasplit_dict_path) as f: + self.d = json.load(f) + return self.d + + def get_model_hparams_dict(self): + """ + Get model_hparams_dict from json file + """ + self.model_hparams_dict_load = {} + if self.hparams_dict_config: + with open(self.hparams_dict_config) as f: + self.model_hparams_dict_load.update(json.loads(f.read())) + return self.model_hparams_dict_load + + + def parse_hparams(self): + """ + Parse the hparams setting to ovoerride the default ones + """ + parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) + def get_default_hparams_dict(self): - default_hparams = super(MovingMnist, self).get_default_hparams_dict() + + """ + The function that contains default hparams + Returns: + A dict with the following hyperparameters. + context_frames : the number of ground-truth frames to pass in at start. 
+ sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train model + lr : learning rate + loss_fun : the loss function + :return: + """ hparams = dict( - context_frames=10,#Bing: Todo oriignal is 10 - sequence_length=20,#bing: TODO original is 20, - shuffle_on_val=True, + context_frames=10, + sequence_length=20, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "rmse", + shuffle_on_val= True, ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - - @property - def jpeg_encoding(self): - return False - + return hparams + + + def get_tfrecords_filename_base_datasplit(self): + """ + Get obsoluate .tfrecords names based on the data splits patterns + """ + self.filenames = [] + self.data_mode = self.data_dict[self.mode] + self.all_filenames = glob.glob(os.path.join(self.input_dir,"*.tfrecords")) + print("self.all_files",self.all_filenames) + for indice_group, index in self.data_mode.items(): + fs = [MovingMnist.string_filter(max_value=index[1], min_value=index[0], string=s) for s in self.all_filenames] + print("fs:",fs) + self.tf_names = [self.all_filenames[fs_index] for fs_index in range(len(fs)) if fs[fs_index]==True] + print("tf_names,",self.tf_names) + # look for tfrecords in input_dir and input_dir/mode directories + for files in self.tf_names: + self.filenames.extend(glob.glob(os.path.join(self.input_dir, files))) + if self.filenames: + self.filenames = sorted(self.filenames) # ensures order is the same across systems + if not self.filenames: + raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir) + + + @staticmethod + def string_filter(max_value=None, min_value=None, string="input_directory/sequence_index_0_index_10.tfrecords"): + a = os.path.split(string)[-1].split("_") + if not len(a) == 5: + raise ("The tfrecords pattern does not match the expected pattern, for instanct: 'sequence_index_0_to_10.tfrecords'") + min_index = int(a[2]) + max_index = int(a[4].split(".")[0]) + if min_index >= min_value and max_index <= max_value: + return True + else: + return False + def get_example_info(self): + """ + Get the data information from tfrecord file + """ + example = next(tf.python_io.tf_record_iterator(self.filenames[0])) + dict_message = MessageToDict(tf.train.Example.FromString(example)) + feature = dict_message['features']['feature'] + print("features in dataset:",feature.keys()) + video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', + 'width', 'channels']) + self.sequence_length = video_shape[0] + self.image_shape = video_shape[1:] def num_examples_per_epoch(self): - with open(os.path.join(self.input_dir, 'number_squences.txt'), 'r') as sequence_lengths_file: - sequence_lengths = sequence_lengths_file.readlines() - sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] - return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) - - def filter(self, serialized_example): - return tf.convert_to_tensor(True) - - - def make_dataset_v2(self, batch_size): + """ + Calculate how many tfrecords samples in the train/val/test + """ + # count how many tfrecords files for train/val/testing + len_fnames = len(self.filenames) + num_seq_file = os.path.join(self.input_dir, 'number_sequences.txt') + with open(num_seq_file, 'r') as dfile: + num_seqs = dfile.readlines() + num_sequences = [int(num_seq.strip()) for num_seq in num_seqs] + self.num_examples_per_epoch = len_fnames * num_sequences[0] + + return 
self.num_examples_per_epoch + + + def make_dataset(self, batch_size): + """ + Prepare batch_size dataset fed into to the models. + If the data are from training dataset,then the data is shuffled; + If the data are from val dataset, the shuffle var will be decided by the hparams.shuffled_on_val; + if the data are from test dataset, the data will not be shuffled + args: + batch_size: int, the size of samples fed into the models per iteration + """ + self.num_epochs = self.hparams.max_epochs def parser(serialized_example): seqs = OrderedDict() keys_to_features = { - 'width': tf.FixedLenFeature([], tf.int64), - 'height': tf.FixedLenFeature([], tf.int64), - 'sequence_length': tf.FixedLenFeature([], tf.int64), - 'channels': tf.FixedLenFeature([], tf.int64), - 'images/encoded': tf.VarLenFeature(tf.float32) - } - - # for i in range(20): - # keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string) + 'width': tf.FixedLenFeature([], tf.int64), + 'height': tf.FixedLenFeature([], tf.int64), + 'sequence_length': tf.FixedLenFeature([], tf.int64), + 'channels': tf.FixedLenFeature([],tf.int64), + 'images/encoded': tf.VarLenFeature(tf.float32) + } parsed_features = tf.parse_single_example(serialized_example, keys_to_features) - print ("Parse features", parsed_features) seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) - #width = tf.sparse_tensor_to_dense(parsed_features["width"]) - # height = tf.sparse_tensor_to_dense(parsed_features["height"]) - # channels = tf.sparse_tensor_to_dense(parsed_features["channels"]) - # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"]) - images = [] - print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2])) - images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new") + print("Image shape {}, {},{},{}".format(self.sequence_length,self.image_shape[0],self.image_shape[1], + self.image_shape[2])) + images = tf.reshape(seq, [self.sequence_length,self.image_shape[0],self.image_shape[1], + self.image_shape[2]], name = "reshape_new") seqs["images"] = images return seqs filenames = self.filenames - print ("FILENAMES",filenames) - #TODO: - #temporal_filenames = self.temporal_filenames shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) if shuffle: random.shuffle(filenames) - dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) # todo: what is buffer_size - print("files", self.filenames) - print("mode", self.mode) - dataset = dataset.filter(self.filter) + dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) if shuffle: - dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs)) + dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count=self.num_epochs)) else: dataset = dataset.repeat(self.num_epochs) - + if self.mode == "val": dataset = dataset.repeat(20) num_parallel_calls = None if shuffle else 1 dataset = dataset.apply(tf.contrib.data.map_and_batch( parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) - #dataset = dataset.map(parser) - # num_parallel_calls = None if shuffle else 1 # for reproducibility (e.g. 
sampled subclips from the test set) - # dataset = dataset.apply(tf.contrib.data.map_and_batch( - # _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs - dataset = dataset.prefetch(batch_size) # Bing: Take the data to buffer inorder to save the waiting time for GPU + dataset = dataset.prefetch(batch_size) return dataset - - def make_batch(self, batch_size): - dataset = self.make_dataset_v2(batch_size) + dataset = self.make_dataset(batch_size) iterator = dataset.make_one_shot_iterator() return iterator.get_next() @@ -129,108 +220,8 @@ def _floats_feature(value): def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) -def save_tf_record(output_fname, sequences): - with tf.python_io.TFRecordWriter(output_fname) as writer: - for i in range(len(sequences)): - sequence = sequences[:,i,:,:,:] - num_frames = len(sequence) - height, width = sequence[0,:,:,0].shape - encoded_sequence = np.array([list(image) for image in sequence]) - features = tf.train.Features(feature={ - 'sequence_length': _int64_feature(num_frames), - 'height': _int64_feature(height), - 'width': _int64_feature(width), - 'channels': _int64_feature(1), - 'images/encoded': _floats_feature(encoded_sequence.flatten()), - }) - example = tf.train.Example(features=features) - writer.write(example.SerializeToString()) - -def read_frames_and_save_tf_records(output_dir,dat_npz, seq_length=20, sequences_per_file=128, height=64, width=64):#Bing: original 128 - """ - Read the moving_mnst data which is npz format, and save it to tfrecords files - The shape of dat_npz is [seq_length,number_samples,height,width] - moving_mnst only has one channel - - """ - os.makedirs(output_dir,exist_ok=True) - idx = 0 - num_samples = dat_npz.shape[1] - dat_npz = np.expand_dims(dat_npz, axis=4) #add one dim to represent channel, then got [seq_length,num_samples,height,width,channel] - print("data_npz_shape",dat_npz.shape) - dat_npz = dat_npz.astype(np.float32) - dat_npz /= 255.0 #normalize RGB codes by dividing it to the max RGB value - while idx < num_samples - sequences_per_file: - sequences = dat_npz[:,idx:idx+sequences_per_file,:,:,:] - output_fname = 'sequence_{}_{}.tfrecords'.format(idx,idx+sequences_per_file) - output_fname = os.path.join(output_dir, output_fname) - save_tf_record(output_fname, sequences) - idx = idx + sequences_per_file - return None - - -def write_sequence_file(output_dir,seq_length,sequences_per_file): - partition_names = ["train","val","test"] - for partition_name in partition_names: - save_output_dir = os.path.join(output_dir,partition_name) - tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords")) - print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter)) - sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w') - for i in range(tfCounter*sequences_per_file): - sequence_lengths_file.write("%d\n" % seq_length) - sequence_lengths_file.close() - - -def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"): - """ - Plot the seq images - """ - - if len(np.array(imgs).shape)!=3:raise("img dims should be three: (seq_len,lat,lon)") - img_len = imgs.shape[0] - fig = plt.figure(figsize=(18,6)) - gs = gridspec.GridSpec(1, 10) - gs.update(wspace = 0., hspace = 0.) 
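As shown in the rewritten MovingMnist class above, the dataset now resolves its own TFRecord files and sequence length; a short usage sketch (not part of the patch, paths are placeholders) illustrates how a batch would be drawn:

from model_modules.video_prediction.datasets.moving_mnist import MovingMnist

dataset = MovingMnist(input_dir="/path/to/moving_mnist/tfrecords",   # placeholder
                      datasplit_config="datasplit.json",             # placeholder
                      hparams_dict_config="model_hparams.json",      # placeholder
                      mode="train")
print(dataset.sequence_length, dataset.image_shape)                  # e.g. 20 and (64, 64, 1)
batch = dataset.make_batch(batch_size=4)                             # dict of tensors; batch["images"] has
                                                                     # shape [4, sequence_length, 64, 64, 1]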
- for i in range(img_len): - ax1 = plt.subplot(gs[i]) - plt.imshow(imgs[i] ,cmap = 'jet') - plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) - plt.savefig(os.path.join(output_png_dir, label + "_" + str(idx) + ".jpg")) - print("images_saved") - plt.clf() + - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking") - parser.add_argument("output_dir", type=str) - parser.add_argument("-sequences_per_file",type=int,default=2) - args = parser.parse_args() - current_path = os.getcwd() - data = np.load(os.path.join(args.input_dir,"mnist_test_seq.npy")) - print("data in minist_test_Seq shape",data.shape) - seq_length = data.shape[0] - height = data.shape[2] - width = data.shape[3] - num_samples = data.shape[1] - max_npz = np.max(data) - min_npz = np.min(data) - print("max_npz,",max_npz) - print("min_npz",min_npz) - #Todo need to discuss how to split the data, since we have totally 10000 samples, the origin paper convLSTM used 10000 as training, 2000 as validation and 3000 for testing - dat_train = data[:,:6000,:,:] - dat_val = data[:,6000:7000,:,:] - dat_test = data[:,7000:,:] - #plot_seq_imgs(dat_test[10:,0,:,:],output_png_dir="/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist/convLSTM",idx=1,label="Ground Truth from npz") - #save train - #read_frames_and_save_tf_records(os.path.join(args.output_dir,"train"),dat_train, seq_length=20, sequences_per_file=40, height=height, width=width) - #save val - #read_frames_and_save_tf_records(os.path.join(args.output_dir,"val"),dat_val, seq_length=20, sequences_per_file=40, height=height, width=width) - #save test - #read_frames_and_save_tf_records(os.path.join(args.output_dir,"test"),dat_test, seq_length=20, sequences_per_file=40, height=height, width=width) - #write_sequence_file(output_dir=args.output_dir,seq_length=20,sequences_per_file=40) -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index 960f608deed07e715190cdecb38efeb2eb4c5ace..2053aeed83a3606804af959e1c422d5cb39723a7 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -12,6 +12,10 @@ from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel from .mcnet_model import McNetVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models +from .vanilla_GAN_model import VanillaGANVideoPredictionModel +from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel + + def get_model_class(model): model_mappings = known_models() diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py new file mode 100644 index 0000000000000000000000000000000000000000..092cf81a0db1adeb3c3bac65cb3bc92e54e3e4c5 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py @@ -0,0 +1,341 @@ +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong,Yanji" +__date__ = "2021-04-13" + +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames +import tensorflow as tf +from 
model_modules.video_prediction.layers import layer_def as ld
+from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+from tensorflow.contrib.training import HParams
+from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
+
+class batch_norm(object):
+    def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
+        with tf.variable_scope(name):
+            self.epsilon = epsilon
+            self.momentum = momentum
+            self.name = name
+
+    def __call__(self, x, train=True):
+        return tf.contrib.layers.batch_norm(x,
+                                            decay=self.momentum,
+                                            updates_collections=None,
+                                            epsilon=self.epsilon,
+                                            scale=True,
+                                            is_training=train,
+                                            scope=self.name)
+
+class ConvLstmGANVideoPredictionModel(object):
+    def __init__(self, mode='train', hparams_dict=None):
+        """
+        This class builds the convLSTM-GAN architecture from the given hyperparameters.
+        args:
+            mode: str, "train" or "val"; mode is not used by the plain convLSTM, but it is a useful argument for the GAN-based model
+            hparams_dict: dict, dictionary containing the hyperparameter names and values
+        """
+        self.mode = mode
+        self.hparams_dict = hparams_dict
+        self.hparams = self.parse_hparams()
+        self.learning_rate = self.hparams.lr
+        self.total_loss = None
+        self.context_frames = self.hparams.context_frames
+        self.sequence_length = self.hparams.sequence_length
+        self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
+        self.max_epochs = self.hparams.max_epochs
+        self.loss_fun = self.hparams.loss_fun
+        self.batch_size = self.hparams.batch_size
+        self.recon_weight = self.hparams.recon_weight
+        self.bd1 = batch_norm(name = "dis1")
+        self.bd2 = batch_norm(name = "dis2")
+        self.bd3 = batch_norm(name = "dis3")
+
+    def get_default_hparams(self):
+        return HParams(**self.get_default_hparams_dict())
+
+    def parse_hparams(self):
+        """
+        Parse the hparams settings to override the default ones
+        """
+        parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
+        return parsed_hparams
+
+    def get_default_hparams_dict(self):
+        """
+        The function that contains the default hparams
+        Returns:
+            A dict with the following hyperparameters.
+            context_frames : the number of ground-truth frames to pass in at start.
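# A small sketch (assumed, hypothetical values) of how parse_hparams() above merges
# user-supplied settings over these defaults via HParams.override_from_dict:
from tensorflow.contrib.training import HParams

defaults = HParams(context_frames=12, sequence_length=24, lr=0.001, recon_weight=0.99)
merged = defaults.override_from_dict({"lr": 0.0005, "recon_weight": 0.9})
print(merged.lr, merged.recon_weight, merged.sequence_length)   # 0.0005 0.9 24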
+ sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train model + lr : learning rate + loss_fun : the loss function + recon_wegiht : the weight for reconstrution loss + """ + hparams = dict( + context_frames=12, + sequence_length=24, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "cross_entropy", + shuffle_on_val= True, + recon_weight=0.99, + + ) + return hparams + + + def build_graph(self, x): + self.is_build_graph = False + self.inputs = x + self.x = x["images"] + self.width = self.x.shape.as_list()[3] + self.height = self.x.shape.as_list()[2] + self.channels = self.x.shape.as_list()[4] + self.global_step = tf.train.get_or_create_global_step() + original_global_variables = tf.global_variables() + # Architecture + self.define_gan() + #This is the loss function (RMSE): + #This is loss function only for 1 channel (temperature RMSE) + #generator los + self.total_loss = (1-self.recon_weight) * self.G_loss + self.recon_weight*self.recon_loss + self.D_loss = (1-self.recon_weight) * self.D_loss + if self.mode == "train": + if self.recon_weight == 1: + print("Only train generator- convLSTM") + self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + else: + print("Training distriminator") + self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars) + with tf.control_dependencies([self.D_solver]): + print("Training generator....") + self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + with tf.control_dependencies([self.G_solver]): + self.train_op = tf.assign_add(self.global_step,1) + else: + self.train_op = None + + self.outputs = {} + self.outputs["gen_images"] = self.gen_images + self.outputs["total_loss"] = self.total_loss + # Summary op + tf.summary.scalar("total_loss", self.total_loss) + tf.summary.scalar("D_loss", self.D_loss) + tf.summary.scalar("G_loss", self.G_loss) + tf.summary.scalar("D_loss_fake", self.D_loss_fake) + tf.summary.scalar("D_loss_real", self.D_loss_real) + tf.summary.scalar("recon_loss",self.recon_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph + + def get_noise(self): + """ + Function for creating noise: Given the dimensions (n_batch,n_seq, n_height, n_width, channel) + """ + self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.batch_size, self.sequence_length, self.height, self.width, self.channels]) + return self.noise + + @staticmethod + def lrelu(x, leak=0.2, name="lrelu"): + return tf.maximum(x, leak*x) + + @staticmethod + def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): + shape = input_.get_shape().as_list() + + with tf.variable_scope(scope or "Linear"): + matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, + tf.random_normal_initializer(stddev=stddev)) + bias = tf.get_variable("bias", [output_size], + initializer=tf.constant_initializer(bias_start)) + if with_w: + return tf.matmul(input_, matrix) + bias, matrix, bias + else: + return tf.matmul(input_, matrix) + bias + + @staticmethod + def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"): + with tf.variable_scope(name): 
+ w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], + initializer=tf.truncated_normal_initializer(stddev=stddev)) + conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') + + biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) + conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + + return conv + + @staticmethod + def bn(x, scope): + return tf.contrib.layers.batch_norm(x, + decay=0.9, + updates_collections=None, + epsilon=1e-5, + scale=True, + scope=scope) + + def generator(self): + """ + Function to build up the generator architecture + args: + input images: a input tensor with dimension (n_batch,sequence_length,height,width,channel) + """ + with tf.variable_scope("generator",reuse=tf.AUTO_REUSE): + layer_gen = self.convLSTM_network(self.x) + layer_gen_pred = layer_gen[:,self.context_frames-1:,:,:,:] + return layer_gen + + + def discriminator(self,vid): + """ + Function that get discriminator architecture + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + conv1 = tf.layers.conv3d(vid,64,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis1") + conv1 = ConvLstmGANVideoPredictionModel.lrelu(conv1) + conv2 = tf.layers.conv3d(conv1,128,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis2") + conv2 = ConvLstmGANVideoPredictionModel.lrelu(self.bd1(conv2)) + conv3 = tf.layers.conv3d(conv2,256,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis3") + conv3 = ConvLstmGANVideoPredictionModel.lrelu(self.bd2(conv3)) + conv4 = tf.layers.conv3d(conv3,512,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis4") + conv4 = ConvLstmGANVideoPredictionModel.lrelu(self.bd3(conv4)) + conv5 = tf.layers.conv3d(conv4,1,kernel_size=[2,4,4],strides=[1,1,1],padding="SAME",name="dis5") + conv5 = tf.reshape(conv5, [-1,1]) + conv5sigmoid = tf.nn.sigmoid(conv5) + return conv5sigmoid,conv5 + + def discriminator0(self,image): + """ + Function that get discriminator architecture + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + layer_disc = self.convLSTM_network(image) + layer_disc = layer_disc[:,self.context_frames-1:self.context_frames,:,:, 0:1] + return layer_disc + + def discriminator1(self,sequence): + """ + https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py + Function that give the possibility of a sequence of frames is ture of false + the input squence shape is like [batch_size,time_seq_length,height,width,channel] (e.g., self.x[:,:self.context_frames,:,:,:]) + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + print(sequence.shape) + x = sequence[:,:,:,:,0:1] # extract targeted variable + x = tf.transpose(x, [0,2,3,1,4]) # sequence shape is like: [batch_size,height,width,time_seq_length] + x = tf.reshape(x,[x.shape[0],x.shape[1],x.shape[2],x.shape[3]]) + print(x.shape) + net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.conv2d(x, 64, 4, 4, 2, 2, name='d_conv1')) + net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'),scope='d_bn2')) + net = tf.reshape(net, [self.batch_size, -1]) + net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.linear(net, 1024, scope='d_fc3'),scope='d_bn3')) + out_logit = ConvLstmGANVideoPredictionModel.linear(net, 1, scope='d_fc4') + out = tf.nn.sigmoid(out_logit) + print(out.shape) + 
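# A small helper (a sketch; assumes "SAME" padding and the default hparams, i.e.
# 12 predicted frames of 64x64 pixels) to track how the strided conv3d stack in
# discriminator() above shrinks the [time, height, width] extent before the final layer:
import math

def same_padding_output_shape(dims, strides):
    return [int(math.ceil(d / float(s))) for d, s in zip(dims, strides)]

shape = [12, 64, 64]
for _ in range(4):                  # dis1 ... dis4 all use strides of [2, 2, 2]
    shape = same_padding_output_shape(shape, [2, 2, 2])
print(shape)                        # -> [1, 4, 4]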
return out, out_logit
+
+    def get_disc_loss(self):
+        """
+        Return the loss of the discriminator given its inputs
+        """
+        real_labels = tf.ones_like(self.D_real)
+        gen_labels = tf.zeros_like(self.D_fake)
+        self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real_logits, labels=real_labels))
+        self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=gen_labels))
+        self.D_loss = self.D_loss_real + self.D_loss_fake
+        return self.D_loss
+
+    def get_gen_loss(self):
+        """
+        Param:
+            num_images: the number of images the generator should produce, which is also the length of the real images
+            z_dim : the dimension of the noise vector, a scalar
+        Return the loss of the generator given its inputs
+        """
+        real_labels = tf.ones_like(self.D_fake)
+        self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=real_labels))
+        return self.G_loss
+
+    def get_vars(self):
+        """
+        Get the trainable variables of the discriminator and the generator
+        """
+        print("trainable_variables", len(tf.trainable_variables()))
+        self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
+        self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
+        print("self.disc_vars", self.disc_vars)
+        print("self.gen_vars", self.gen_vars)
+
+    def define_gan(self):
+        """
+        Define the GAN architecture
+        """
+        self.noise = self.get_noise()
+        self.gen_images = self.generator()
+        # Note: the input of the discriminator has to be adapted when a different discriminator variant is used
+        self.D_real, self.D_real_logits = self.discriminator(self.x[:, self.context_frames:, :, :, :])
+        self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:, self.context_frames-1:, :, :, :])
+        self.get_gen_loss()
+        self.get_disc_loss()
+        self.get_vars()
+        if self.loss_fun == "rmse":
+            self.recon_loss = tf.reduce_mean(tf.square(self.x[:, self.context_frames:, :, :, 0] - self.gen_images[:, self.context_frames-1:, :, :, 0]))
+        elif self.loss_fun == "cross_entropy":
+            x_flatten = tf.reshape(self.x[:, self.context_frames:, :, :, 0], [-1])
+            x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:, self.context_frames-1:, :, :, 0], [-1])
+            bce = tf.keras.losses.BinaryCrossentropy()
+            self.recon_loss = bce(x_flatten, x_hat_predict_frames_flatten)
+        else:
+            raise ValueError("Loss function is not selected properly, you should choose either 'rmse' or 'cross_entropy'")
+
+    @staticmethod
+    def convLSTM_cell(inputs, hidden):
+        y_0 = inputs  # we only use patch 1, but the original paper uses patch 4 for the moving mnist case and patch 2 for the Radar Echo dataset
+        channels = inputs.get_shape()[-1]
+        # conv lstm cell
+        cell_shape = y_0.get_shape().as_list()
+        channels = cell_shape[-1]
+        with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
+            cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64)
+            if hidden is None:
+                hidden = cell.zero_state(y_0, tf.float32)
+            output, hidden = cell(y_0, hidden)
+        output_shape = output.get_shape().as_list()
+        z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
+        # we feed the learned representation into a 1 × 1 convolutional layer to generate the final prediction
+        x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid")
+        print('x_hat shape is: ', x_hat.shape)
+        return x_hat, hidden
+
+    def convLSTM_network(self, x):
+        network_template = 
tf.make_template('network',VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables + # create network + x_hat = [] + + #This is for training (optimization of convLSTM layer) + hidden_g = None + for i in range(self.sequence_length-1): + if i < self.context_frames: + x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g) + else: + x_1_g, hidden_g = network_template(x_1_g, hidden_g) + x_hat.append(x_1_g) + + # pack them all together + x_hat = tf.stack(x_hat) + self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim ???? yan: why? + print('self.x_hat shape is: ',self.x_hat.shape) + return self.x_hat + + + diff --git a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py index a946bd555a603fd9be14306929e0a8e722a24673..2c755417c27bca22b2a20dc6816add43f160c5a6 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py @@ -3,22 +3,13 @@ __author__ = "Bing Gong" __date__ = "2020-08-22" -import collections -import functools import itertools -from collections import OrderedDict import numpy as np import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops + +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path -from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from model_modules.video_prediction.layers.mcnet_ops import * from model_modules.video_prediction.utils.mcnet_utils import * @@ -32,7 +23,7 @@ class McNetVideoPredictionModel(BaseVideoPredictionModel): self.lr = self.hparams.lr self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.df_dim = self.hparams.df_dim self.gf_dim = self.hparams.gf_dim self.alpha = self.hparams.alpha diff --git a/video_prediction_tools/model_modules/video_prediction/models/model_helpers.py b/video_prediction_tools/model_modules/video_prediction/models/model_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..a81f4631e3edde6efcfc807403ed9b061762e4c3 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/model_helpers.py @@ -0,0 +1,28 @@ +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong, Michael Langguth" +__date__ = "2021-20-05" + +""" +Some auxiliary functions that can be used by any video prediction model +""" + + +def set_and_check_pred_frames(seq_length, context_frames): + """ + Checks if sequence length and context_frames are set properly and returns number of frames to be predicted. 
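# A quick behavioural sketch (hypothetical values) of the helper defined here:
from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames

print(set_and_check_pred_frames(24, 12))   # -> 12 frames to be predicted
try:
    set_and_check_pred_frames(12, 24)      # context frames exceed the sequence length
except ValueError as err:
    print(err)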
+ :param seq_length: number of frames/images per sequences + :param context_frames: number of context frames/images + :return: number of predicted frames + """ + + method = set_and_check_pred_frames.__name__ + + # sanity checks + assert isinstance(seq_length, int), "%{0}: Sequence length (seq_length) must be an integer".format(method) + assert isinstance(context_frames, int), "%{0}: Number of context frames must be an integer".format(method) + + if seq_length > context_frames: + return seq_length - context_frames + else: + raise ValueError("%{0}: Sequence length ({1}) must be larger than context frames ({2})." + .format(method, seq_length, context_frames)) \ No newline at end of file diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py new file mode 100644 index 0000000000000000000000000000000000000000..112eaf31e1d9961e402ddbac1a91b37a6b7b9b90 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py @@ -0,0 +1,230 @@ +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2021=01-05" + + + +""" +This code implement take the following as references: +1) https://stackabuse.com/introduction-to-gans-with-python-and-tensorflow/ +2) cousera GAN courses +3) https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py +""" + +import tensorflow as tf + +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames +from model_modules.video_prediction.layers import layer_def as ld +from tensorflow.contrib.training import HParams + +class VanillaGANVideoPredictionModel(object): + def __init__(self, mode='train', hparams_dict=None): + """ + This is class for building vanilla GAN architecture by using updated hparameters + args: + mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model + hparams_dict: dict, the dictionary contains the hparaemters names and values + """ + self.mode = mode + self.hparams_dict = hparams_dict + self.hparams = self.parse_hparams() + self.learning_rate = self.hparams.lr + self.total_loss = None + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) + self.max_epochs = self.hparams.max_epochs + self.loss_fun = self.hparams.loss_fun + self.batch_size = self.hparams.batch_size + self.z_dim = self.hparams.z_dim # dim of noise-vector + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) + + def parse_hparams(self): + """ + Parse the hparams setting to ovoerride the default ones + """ + + parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams + + + def get_default_hparams_dict(self): + """ + The function that contains default hparams + Returns: + A dict with the following hyperparameters. + context_frames : the number of ground-truth frames to pass in at start. 
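# Sketch of the per-frame flattening that build_graph() further below applies to the
# input (assumed shapes: batch_size=40, sequence_length=24, 64x64 single-channel frames):
import numpy as np

x = np.zeros((40, 24, 64, 64, 1), dtype=np.float32)
n_samples = x.shape[0] * x.shape[1]      # 960: every frame is treated as an independent image
x_flat = x.reshape(-1, 64, 64, 1)        # mirrors tf.reshape(self.x, [-1, height, width, channels])
print(n_samples, x_flat.shape)           # 960 (960, 64, 64, 1)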
+ sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train model + lr : learning rate + loss_fun : the loss function + """ + hparams = dict( + context_frames=12, + sequence_length=24, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "cross_entropy", + shuffle_on_val= True, + z_dim = 32, + ) + return hparams + + + def build_graph(self, x): + self.is_build_graph = False + self.x = x["images"] + self.width = self.x.shape.as_list()[3] + self.height = self.x.shape.as_list()[2] + self.channels = self.x.shape.as_list()[4] + self.n_samples = self.x.shape.as_list()[0] * self.x.shape.as_list()[1] + self.x = tf.reshape(self.x, [-1, self.height,self.width,self.channels]) + self.global_step = tf.train.get_or_create_global_step() + original_global_variables = tf.global_variables() + # Architecture + self.define_gan() + #This is the loss function (RMSE): + #This is loss function only for 1 channel (temperature RMSE) + if self.mode == "train": + self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars) + with tf.control_dependencies([self.D_solver]): + self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.G_loss, var_list=self.gen_vars) + with tf.control_dependencies([self.G_solver]): + self.train_op = tf.assign_add(self.global_step,1) + else: + self.train_op = None + self.total_loss = self.G_loss + self.D_loss + self.outputs = {} + self.outputs["gen_images"] = self.gen_images + self.outputs["total_loss"] = self.total_loss + # Summary op + self.loss_summary = tf.summary.scalar("total_loss", self.G_loss + self.D_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph + + def get_noise(self): + """ + Function for creating noise: Given the dimensions (n_samples,z_dim) + """ + self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.n_samples, self.height, self.width, self.channels]) + return self.noise + + def get_generator_block(self,inputs,output_dim,idx): + + """ + Generator Block + Function for return a neural network of the generator given input and output dimensions + args: + inputs : the input vector + output_dim: the dimeniosn of output vector + return: + a generator neural network layer, with a convolutional layers followed by batch normalization and a relu activation + + """ + output1 = ld.conv_layer(inputs,kernel_size=2,stride=1,num_features=output_dim,idx=idx,activate="linear") + output2 = ld.bn_layers(output1,idx,is_training=False) + output3 = tf.nn.relu(output2) + return output3 + + + def generator(self,hidden_dim): + """ + Function to build up the generator architecture + args: + noise: a noise tensor with dimension (n_samples,height,width,channel) + hidden_dim: the inner dimension + """ + with tf.variable_scope("generator",reuse=tf.AUTO_REUSE): + layer1 = self.get_generator_block(self.noise,hidden_dim,1) + layer2 = self.get_generator_block(layer1,hidden_dim*2,2) + layer3 = self.get_generator_block(layer2,hidden_dim*4,3) + layer4 = self.get_generator_block(layer3,hidden_dim*8,4) + layer5 = ld.conv_layer(layer4,kernel_size=2,stride=1,num_features=self.channels,idx=5,activate="linear") + layer6 = tf.nn.sigmoid(layer5,name="6_conv") + print("layer6",layer6) + return layer6 + + + + def 
get_discriminator_block(self,inputs,output_dim,idx): + + """ + Distriminator block + Function for ruturn a neural network of a descriminator given input and output dimensions + + args: + inputs : the dimension of input vector + output_dim: the dimension of output dim + idx: : the index for the namespace of this block + Return: + a distriminator neural network layer with a convolutional layers followed by a leakyRelu function + """ + output1 = ld.conv_layer(inputs,2,stride=1,num_features=output_dim,idx=idx,activate="linear") + output2 = tf.nn.leaky_relu(output1) + return output2 + + + def discriminator(self,image,hidden_dim): + """ + Function that get discriminator architecture + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + layer1 = self.get_discriminator_block(image,hidden_dim,idx=1) + layer2 = self.get_discriminator_block(layer1,hidden_dim*4,idx=2) + layer3 = self.get_discriminator_block(layer2,hidden_dim*2,idx=3) + layer4 = self.get_discriminator_block(layer3, self.channels,idx=4) + layer5 = tf.nn.sigmoid(layer4) + return layer5 + + + def get_disc_loss(self): + """ + Return the loss of discriminator given inputs + """ + + real_labels = tf.ones_like(self.D_real) + gen_labels = tf.zeros_like(self.D_fake) + D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real, labels=real_labels)) + D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=gen_labels)) + self.D_loss = D_loss_real + D_loss_fake + return self.D_loss + + + def get_gen_loss(self): + """ + Param: + num_images: the number of images the generator should produce, which is also the lenght of the real image + z_dim : the dimension of the noise vector, a scalar + Return the loss of generator given inputs + """ + real_labels = tf.ones_like(self.gen_images) + self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=real_labels)) + return self.G_loss + + def get_vars(self): + """ + Get trainable variables from discriminator and generator + """ + self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] + self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] + + + + def define_gan(self): + """ + Define gan architectures + """ + self.noise = self.get_noise() + self.gen_images = self.generator(hidden_dim=8) + self.D_real = self.discriminator(self.x,hidden_dim=8) + self.D_fake = self.discriminator(self.gen_images,hidden_dim=8) + self.get_gen_loss() + self.get_disc_loss() + self.get_vars() + diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py index e78f973b2aef760e1fccc24d296b31ee67f2e0c8..570ece3e8e752a0520bf1593bc5e0b922cece3a0 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py @@ -2,25 +2,14 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth" __date__ = "2020-11-05" -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from 
model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from tensorflow.contrib.training import HParams + + class VanillaConvLstmVideoPredictionModel(object): def __init__(self, mode='train', hparams_dict=None): """ @@ -36,7 +25,7 @@ class VanillaConvLstmVideoPredictionModel(object): self.total_loss = None self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.max_epochs = self.hparams.max_epochs self.loss_fun = self.hparams.loss_fun @@ -67,17 +56,18 @@ class VanillaConvLstmVideoPredictionModel(object): hparams = dict( context_frames=10, sequence_length=20, - max_epochs = 20, - batch_size = 40, - lr = 0.001, - loss_fun = "cross_entropy", - shuffle_on_val= True, + max_epochs=20, + batch_size=40, + lr=0.001, + loss_fun="cross_entropy", + shuffle_on_val=True, ) return hparams def build_graph(self, x): self.is_build_graph = False + self.inputs = x self.x = x["images"] self.global_step = tf.train.get_or_create_global_step() original_global_variables = tf.global_variables() @@ -112,24 +102,6 @@ class VanillaConvLstmVideoPredictionModel(object): self.is_build_graph = True return self.is_build_graph - @staticmethod - def convLSTM_cell(inputs, hidden): - y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset - channels = inputs.get_shape()[-1] - # conv lstm cell - cell_shape = y_0.get_shape().as_list() - channels = cell_shape[-1] - with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)): - cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64) - if hidden is None: - hidden = cell.zero_state(y_0, tf.float32) - output, hidden = cell(y_0, hidden) - output_shape = output.get_shape().as_list() - z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) - #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction - x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") - return x_hat, hidden - def convLSTM_network(self): network_template = tf.make_template('network', VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables @@ -150,3 +122,41 @@ class VanillaConvLstmVideoPredictionModel(object): self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim self.x_hat_predict_frames = self.x_hat[:,self.context_frames-1:,:,:,:] + @staticmethod + def convLSTM_cell(inputs, hidden): + y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset + channels = inputs.get_shape()[-1] + # conv lstm cell + cell_shape = y_0.get_shape().as_list() + channels = cell_shape[-1] + with tf.variable_scope('conv_lstm', initializer = 
tf.random_uniform_initializer(-.01, 0.1)): + cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64) + if hidden is None: + hidden = cell.zero_state(y_0, tf.float32) + output, hidden = cell(y_0, hidden) + output_shape = output.get_shape().as_list() + z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) + #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction + x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") + return x_hat, hidden + + @staticmethod + def set_and_check_pred_frames(seq_length, context_frames): + """ + Checks if sequence length and context_frames are set properly and returns number of frames to be predicted. + :param seq_length: number of frames/images per sequences + :param context_frames: number of context frames/images + :return: number of predicted frames + """ + + method = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames.__name__ + + # sanity checks + assert isinstance(seq_length, int), "%{0}: Sequence length (seq_length) must be an integer".format(method) + assert isinstance(context_frames, int), "%{0}: Number of context frames must be an integer".format(method) + + if seq_length > context_frames: + return seq_length-context_frames + else: + raise ValueError("%{0}: Sequence length ({1}) must be larger than context frames ({2})." + .format(method, seq_length, context_frames)) diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py index 986e3626fe0746c2c714ee9fa9a76ad873044415..3e74b23b9d4969544ef0e560e30cd19e4a983f6b 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py @@ -3,21 +3,8 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong" __date__ = "2020-09-01" -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path from model_modules.video_prediction.layers import layer_def as ld from tensorflow.contrib.training import HParams @@ -37,7 +24,7 @@ class VanillaVAEVideoPredictionModel(object): self.total_loss = None self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.max_epochs = self.hparams.max_epochs self.nz = self.hparams.nz self.loss_fun = self.hparams.loss_fun diff --git a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py index 
a1aa0dba9c5ca1e230c254a4d0a7b5439db0d5e3..d420ed1be8460ade5fae3302960de8b761f49f72 100755 --- a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py +++ b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py @@ -20,7 +20,7 @@ class Config_Preprocess1(Config_runscript_base): cls_name = "Config_Preprocess1"#.__name__ - nvars = 3 # number of variables required for training + nvars_default = 3 # number of variables required for training def __init__(self, venv_name, lhpc): super().__init__(venv_name, lhpc) @@ -35,7 +35,7 @@ class Config_Preprocess1(Config_runscript_base): # initialize additional runscript-specific attributes to be set via keyboard interaction self.destination_dir = None self.years = None - self.variables = [None] * self.nvars + self.variables = [] self.sw_corner = [-999., -999.] # [np.nan, np.nan] self.nyx = [-999., -999.] # [np.nan, np.nan] # list of variables to be written to runscript @@ -54,8 +54,18 @@ class Config_Preprocess1(Config_runscript_base): """ method_name = Config_Preprocess1.run_preprocess1.__name__ - # get source_dir (no user interaction needed when directory tree is fixed) - self.source_dir = Config_Preprocess1.handle_source_dir(self, "extractedData") + src_dir_req_str = "Enter path to directory where netCDF-files of the ERA5 dataset are located " + \ + "(in yearly directories.). Just press enter if the default should be used." + sorurce_dir_err = NotADirectoryError("Passed directory does not exist.") + source_dir_str = Config_Preprocess1.keyboard_interaction(src_dir_req_str, Config_Preprocess1.src_dir_check, + sorurce_dir_err, ntries=3) + if not source_dir_str: + # standard source_dir + self.source_dir = Config_Preprocess1.handle_source_dir(self, "extractedData") + print("%{0}: The following standard base-directory obtained from runscript template was set: '{1}'".format(method_name, self.source_dir)) + else: + self.source_dir = source_dir_str + Config_Preprocess1.get_subdir_list(self.source_dir) # get years for preprocessing step 1 years_req_str = "Enter a comma-separated sequence of years from list above:" @@ -86,8 +96,9 @@ class Config_Preprocess1(Config_runscript_base): vars_err, ntries=2) vars_list = vars_str.split(",") + vars_list = [var.strip().lower() for var in vars_list] if len(vars_list) == 1: - self.variables = vars_list * Config_Preprocess1.nvars + self.variables = vars_list * Config_Preprocess1.nvars_default else: self.variables = [var.strip() for var in vars_list] @@ -169,6 +180,29 @@ class Config_Preprocess1(Config_runscript_base): str(year))) # auxiliary functions for keyboard interaction + @staticmethod + def src_dir_check(srcdir, silent=False): + """ + Checks if source directory exists. Also allows for empty strings. In this case, a default of the source + directory must be applied. 
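# Behavioural sketch of the directory check described here (hypothetical paths):
# an empty answer is accepted so that the default source directory can be applied,
# an existing directory passes, and a missing one fails.
import os

for candidate in ["", "/tmp", "/does/not/exist"]:
    status = True if not candidate else os.path.isdir(candidate)
    print(repr(candidate), "->", status)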
+ :param srcdir: directory path under which ERA5 netCDF-data is stored + :param silent: flag if print-statement are executed + :return: status with True confirming success + """ + method = Config_Preprocess1.src_dir_check.__name__ + + status = False + if srcdir: + if os.path.isdir(srcdir): + status = True + else: + if not silent: + print("%{0}: '{1}' does not exist.".format(method, srcdir)) + else: + status = True + + return status + @staticmethod def check_data_indir(indir, silent=False, recursive=True): """ @@ -214,9 +248,10 @@ class Config_Preprocess1(Config_runscript_base): check_years = [year.strip().isnumeric() for year in years_list] status = all(check_years) if not status: - inds_bad = [i for i, e in enumerate(check_years) if e] #np.where(~np.array(check_years))[0] + inds_bad = [i for i, e in enumerate(check_years) if not e] #np.where(~np.array(check_years))[0] if not silent: - print("%{0}: The following comma-separated elements could not be interpreted as valid years:".format(method)) + print("%{0}: The following comma-separated elements could not be interpreted as valid years:" + .format(method)) for ind in inds_bad: print(years_list[ind]) return status @@ -245,15 +280,15 @@ class Config_Preprocess1(Config_runscript_base): check_vars = [var.strip().lower() in known_vars for var in vars_list] status = all(check_vars) if not status: - inds_bad = [i for i, e in enumerate(check_vars) if e] # np.where(~np.array(check_vars))[0] + inds_bad = [i for i, e in enumerate(check_vars) if e] # np.where(~np.array(check_vars))[0] if not silent: - print("%{0}: The following comma-separated elements are unknown variables:".format(method_name)) + print("%{0}: The following comma-separated elements are unknown variables:".format(method)) for ind in inds_bad: print(vars_list[ind]) return status - if not (len(check_vars) == Config_Preprocess1.nvars or len(check_vars) == 1): - if not silent: print("%{0}: Unexpected number of variables passed.".method(method)) + if not len(check_vars) >= 1: + if not silent: print("%{0}: Pass at least one input variable".format(method)) status = False return status @@ -321,9 +356,11 @@ class Config_Preprocess1(Config_runscript_base): else: if not silent: if not check_nyx[0]: - print("%{0}: Number of grid points in meridional direction must be smaller than {1:d}".format(method, ny_max)) + print("%{0}: Number of grid points in meridional direction must be smaller than {1:d}" + .format(method, ny_max)) if not check_nyx[1]: - print("%{0}: Number of grid points in zonal direction must be smaller than {1:d}".format(method, nx_max)) + print("%{0}: Number of grid points in zonal direction must be smaller than {1:d}" + .format(method, nx_max)) else: if not silent: print("%{0}: Number of grid points must be integers.".format(method))
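# A short sketch of the comma-separated years check performed above
# (input string is hypothetical; mirrors the corrected "inds_bad" indexing):
years_str = "2007, 2010, 20x7"
years_list = years_str.split(",")
check_years = [year.strip().isnumeric() for year in years_list]
if not all(check_years):
    inds_bad = [i for i, ok in enumerate(check_years) if not ok]
    for ind in inds_bad:
        print("Not a valid year:", years_list[ind].strip())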