diff --git a/test/run_pytest.sh b/test/run_pytest.sh
index ed2b531562282d8fc6a34c07cf8e46cad1a83460..83220d34a51379e93add931ae6e03e9491b5bce4 100644
--- a/test/run_pytest.sh
+++ b/test/run_pytest.sh
@@ -2,7 +2,7 @@
 
 # Name of virtual environment 
 #VIRT_ENV_NAME="vp_new_structure"
-VIRT_ENV_NAME="env_hdfml"
+VIRT_ENV_NAME="juwels_env"
 
 if [ -z ${VIRTUAL_ENV} ]; then
    if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then
diff --git a/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json
new file mode 100644
index 0000000000000000000000000000000000000000..bd0357a180631f1ba7f0d0b1732af1d18aa878c3
--- /dev/null
+++ b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json
@@ -0,0 +1,15 @@
+
+{
+    "batch_size": 4,
+    "lr": 0.001,
+    "max_epochs":20,
+    "context_frames":12,
+    "sequence_length":24,
+    "loss_fun":"rmse",
+    "shuffle_on_val":false,
+    "recon_weight":0.6
+
+}
+
+
+
diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index c02378a702f5d807210cbef890b507fc99a8c7c7..c6ac709d7d5d82487e60ab915bf7b41cd11ffabc 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -177,8 +177,8 @@ class TrainModel(object):
         self.inputs = self.iterator.get_next()
         #since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train the model,
         # otherwise the model will raise error
-        if self.dataset == "era5" and self.model == "savp":
-           del self.inputs["T_start"]
+        #if self.dataset == "era5" and self.model == "savp":
+        #   del self.inputs["T_start"]
 
 
 
@@ -231,6 +231,7 @@ class TrainModel(object):
         self.num_examples = self.train_dataset.num_examples_per_epoch()
         self.steps_per_epoch = int(self.num_examples/batch_size)
         self.total_steps = self.steps_per_epoch * max_epochs
+        print("Batch size is {} ; max_epochs is {}; num_samples per epoch is {}; steps_per_epoch is {}, total steps is {}".format(batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps))
 
     def restore(self,sess, checkpoints, restore_to_checkpoint_mapping=None):
         """
@@ -292,11 +293,15 @@ class TrainModel(object):
                 self.create_fetches_for_train()             # In addition to the loss, we fetch the optimizer
                 self.results = sess.run(self.fetches)       # ...and run it here!
                 train_losses.append(self.results["total_loss"])
+                print("t_start for training",self.results["inputs"]["T_start"])
+                print("len of t_start per iteration",len(self.results["inputs"]["T_start"]))
                 #Run and fetch losses for validation data
                 val_handle_eval = sess.run(self.val_handle)
                 self.create_fetches_for_val()
                 self.val_results = sess.run(self.val_fetches,feed_dict={self.train_handle: val_handle_eval})
                 val_losses.append(self.val_results["total_loss"])
+                print("t_start for validation",self.val_results["inputs"]["T_start"])
+                print("len of t_start per iteration",len(self.val_results["inputs"]["T_start"]))
                 self.write_to_summary()
                 self.print_results(step,self.results)
                 timeit_end = time.time()
@@ -333,6 +338,8 @@ class TrainModel(object):
        if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": self.fetches_for_train_convLSTM()
        if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.fetches_for_train_savp()
        if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": self.fetches_for_train_vae()
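+       # GAN variants: the vanilla GAN fetches only the total loss,
+       # while the ConvLSTM-GAN reuses the convLSTM fetches (total loss and inputs)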
+       if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":self.fetches_for_train_gan()
+       if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":self.fetches_for_train_convLSTM()
        return self.fetches     
     
     def fetches_for_train_convLSTM(self):
@@ -340,8 +347,7 @@ class TrainModel(object):
         Fetch variables in the graph for convLSTM model, this can be custermized based on models and the needs of users
         """
         self.fetches["total_loss"] = self.video_model.total_loss
- 
-
+        self.fetches["inputs"] = self.video_model.inputs
 
  
     def fetches_for_train_savp(self):
@@ -353,7 +359,7 @@ class TrainModel(object):
         self.fetches["d_loss"] = self.video_model.d_loss
         self.fetches["g_loss"] = self.video_model.g_loss
         self.fetches["total_loss"] = self.video_model.g_loss
-
+        self.fetches["inputs"] = self.video_model.inputs
 
 
     def fetches_for_train_mcnet(self):
@@ -372,15 +378,19 @@ class TrainModel(object):
         self.fetches["recon_loss"] = self.video_model.recon_loss
         self.fetches["total_loss"] = self.video_model.total_loss
 
+    def fetches_for_train_gan(self):
+        self.fetches["total_loss"] = self.video_model.total_loss
+
     def create_fetches_for_val(self):
         """
         Fetch variables in the graph for validation dataset, this can be custermized based on models and the needs of users
         """
         if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
             self.val_fetches = {"total_loss": self.video_model.g_loss}
+            self.val_fetches["inputs"] = self.video_model.inputs
         else:
             self.val_fetches = {"total_loss": self.video_model.total_loss}
-        
+            self.val_fetches["inputs"] = self.video_model.inputs
         self.val_fetches["summary"] = self.video_model.summary_op
 
     def write_to_summary(self):
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 68017b4597080b771b00860ab32cf693c0714d73..60dfef0032a7746d57734285a7b86328e0c74f50 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -422,6 +422,7 @@ class Postprocess(TrainModel):
             # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel]
             feed_dict = {input_ph: input_results[name] for name, input_ph in self.inputs.items()}
             gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)
+
             # sanity check on length of forecast sequence
             assert gen_images.shape[1] == self.sequence_length - 1, \
                 "%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(method)
diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py
index ca602a954a107c9217942e7f01e4eae4c68d58bb..5836ab9fce48692252a4dbc44415b4a4f9e2c2c3 100644
--- a/video_prediction_tools/model_modules/model_architectures.py
+++ b/video_prediction_tools/model_modules/model_architectures.py
@@ -14,6 +14,8 @@ def known_models():
         'vae': 'VanillaVAEVideoPredictionModel',
         'convLSTM': 'VanillaConvLstmVideoPredictionModel',
         'mcnet': 'McNetVideoPredictionModel',
+        'gan': "VanillaGANVideoPredictionModel",
+        'convLSTM_gan': "ConvLstmGANVideoPredictionModel",
         'ours_vae_l1': 'SAVPVideoPredictionModel',
         'ours_gan': 'SAVPVideoPredictionModel',
     }
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
index d26697cbfd8ab6f94c5316651a33dc772195db28..ce62965a2c92432ffbf739e933279f91b69e355c 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
@@ -181,7 +181,9 @@ class ERA5Dataset(object):
             dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs))
         else:
             dataset = dataset.repeat(self.num_epochs)
-        if self.mode == "val": dataset = dataset.repeat(20)
+
+        if self.mode == "val": dataset = dataset.repeat(20) 
+
         num_parallel_calls = None if shuffle else 1
         dataset = dataset.apply(tf.contrib.data.map_and_batch(
             parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
index 960f608deed07e715190cdecb38efeb2eb4c5ace..2053aeed83a3606804af959e1c422d5cb39723a7 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
@@ -12,6 +12,10 @@ from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
 from .mcnet_model import McNetVideoPredictionModel
 from .test_model import TestModelVideoPredictionModel
 from model_modules.model_architectures import known_models
+from .vanilla_GAN_model import  VanillaGANVideoPredictionModel
+from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel
+
+
 
 def get_model_class(model):
     model_mappings = known_models()
diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ab3a423a001e903ec4ca9fe1bd7ec78e18dc731
--- /dev/null
+++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
@@ -0,0 +1,354 @@
+__email__ = "b.gong@fz-juelich.de"
+__author__ = "Bing Gong,Yanji"
+__date__ = "2021-04-13"
+
+import collections
+import functools
+import itertools
+from collections import OrderedDict
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.util import nest
+from model_modules.video_prediction import ops, flow_ops
+from model_modules.video_prediction.models import BaseVideoPredictionModel
+from model_modules.video_prediction.models import networks
+from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from model_modules.video_prediction.utils import tf_utils
+from datetime import datetime
+from pathlib import Path
+from model_modules.video_prediction.layers import layer_def as ld
+from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+from tensorflow.contrib.training import HParams
+from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
+
+class batch_norm(object):
+  def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
+    with tf.variable_scope(name):
+      self.epsilon  = epsilon
+      self.momentum = momentum
+      self.name = name
+
+  def __call__(self, x, train=True):
+    return tf.contrib.layers.batch_norm(x,
+                      decay=self.momentum,
+                      updates_collections=None,
+                      epsilon=self.epsilon,
+                      scale=True,
+                      is_training=train,
+                      scope=self.name)
+
+class ConvLstmGANVideoPredictionModel(object):
+    def __init__(self, mode='train', hparams_dict=None):
+        """
+        This is the class for building the convLSTM-GAN architecture from the updated hyperparameters
+        args:
+             mode        : str, "train" or "val"; the mode may not be used by the convLSTM itself, but it is a useful argument for the GAN-based model
+             hparams_dict: dict, dictionary containing the hyperparameter names and values
+        """
+        self.mode = mode
+        self.hparams_dict = hparams_dict
+        self.hparams = self.parse_hparams()        
+        self.learning_rate = self.hparams.lr
+        self.total_loss = None
+        self.context_frames = self.hparams.context_frames
+        self.sequence_length = self.hparams.sequence_length
+        self.predict_frames = self.sequence_length - self.context_frames
+        self.max_epochs = self.hparams.max_epochs
+        self.loss_fun = self.hparams.loss_fun
+        self.batch_size = self.hparams.batch_size
+        self.recon_weight = self.hparams.recon_weight
+        self.bd1 = batch_norm(name = "dis1")
+        self.bd2 = batch_norm(name = "dis2")
+        self.bd3 = batch_norm(name = "dis3")   
+
+    def get_default_hparams(self):
+        return HParams(**self.get_default_hparams_dict())
+
+    def parse_hparams(self):
+        """
+        Parse the hparams setting to override the default ones
+        """
+        
+        parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
+        return parsed_hparams
+
+
+    def get_default_hparams_dict(self):
+        """
+        The function that contains default hparams
+        Returns:
+            A dict with the following hyperparameters.
+            context_frames  : the number of ground-truth frames to pass in at start.
+            sequence_length : the number of frames in the video sequence 
+            max_epochs      : the number of epochs to train model
+            lr              : learning rate
+            loss_fun        : the loss function
+            recon_weight    : the weight for the reconstruction loss
+            """
+        hparams = dict(
+            context_frames=12,
+            sequence_length=24,
+            max_epochs = 20,
+            batch_size = 40,
+            lr = 0.001,
+            loss_fun = "cross_entropy",
+            shuffle_on_val= True,
+            recon_weight=0.99,
+          
+         )
+        return hparams
+
+
+    def build_graph(self, x):
+        self.is_build_graph = False
+        self.inputs = x
+        self.x = x["images"]
+        self.width = self.x.shape.as_list()[3]
+        self.height = self.x.shape.as_list()[2]
+        self.channels = self.x.shape.as_list()[4]
+        self.global_step = tf.train.get_or_create_global_step()
+        original_global_variables = tf.global_variables()
+        # Architecture
+        self.define_gan()
+        # Loss functions:
+        # the reconstruction loss is computed on one channel only (temperature RMSE)
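+        # the total generator objective blends the adversarial loss and the
+        # reconstruction loss via recon_weight; with recon_weight == 1 only the
+        # plain convLSTM generator is trained (see the branch below)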
+        self.total_loss = (1-self.recon_weight) * self.G_loss + self.recon_weight*self.recon_loss
+        self.D_loss =  (1-self.recon_weight) * self.D_loss
+        if self.mode == "train":
+            if self.recon_weight == 1:
+                print("Only train generator- convLSTM") 
+                self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) 
+            else:
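+                # each training step first updates the discriminator and then, via the
+                # control dependency, the generator, before incrementing the global step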
+                print("Training distriminator")
+                self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars)
+                with tf.control_dependencies([self.D_solver]):
+                    print("Training generator....")
+                    self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars)
+                with tf.control_dependencies([self.G_solver]):
+                    self.train_op = tf.assign_add(self.global_step,1)
+        else:
+           self.train_op = None 
+
+        self.outputs = {}
+        self.outputs["gen_images"] = self.gen_images
+        self.outputs["total_loss"] = self.total_loss
+        # Summary op
+        tf.summary.scalar("total_loss", self.total_loss)
+        tf.summary.scalar("D_loss", self.D_loss)
+        tf.summary.scalar("G_loss", self.G_loss)
+        tf.summary.scalar("D_loss_fake", self.D_loss_fake) 
+        tf.summary.scalar("D_loss_real", self.D_loss_real)
+        tf.summary.scalar("recon_loss",self.recon_loss)
+        self.summary_op = tf.summary.merge_all()
+        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        self.saveable_variables = [self.global_step] + global_variables
+        self.is_build_graph = True
+        return self.is_build_graph 
+    
+    def get_noise(self):
+        """
+        Function for creating noise of shape (batch_size, sequence_length, height, width, channels)
+        """ 
+        self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.batch_size, self.sequence_length, self.height, self.width, self.channels])
+        return self.noise
+     
+    @staticmethod
+    def lrelu(x, leak=0.2, name="lrelu"):
+        return tf.maximum(x, leak*x)
+
+    @staticmethod    
+    def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
+        shape = input_.get_shape().as_list()
+
+        with tf.variable_scope(scope or "Linear"):
+            matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32,
+                     tf.random_normal_initializer(stddev=stddev))
+            bias = tf.get_variable("bias", [output_size],
+            initializer=tf.constant_initializer(bias_start))
+            if with_w:
+                return tf.matmul(input_, matrix) + bias, matrix, bias
+            else:
+                return tf.matmul(input_, matrix) + bias
+     
+    @staticmethod
+    def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"):
+        with tf.variable_scope(name):
+            w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
+                  initializer=tf.truncated_normal_initializer(stddev=stddev))
+            conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
+
+            biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
+            conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
+
+        return conv
+
+    @staticmethod
+    def bn(x, scope):
+        return tf.contrib.layers.batch_norm(x,
+                                        decay=0.9,
+                                        updates_collections=None,
+                                        epsilon=1e-5,
+                                        scale=True,
+                                        scope=scope)
+
+    def generator(self):
+        """
+        Function to build up the generator architecture
+        args:
+            input images: an input tensor with dimensions (batch_size, sequence_length, height, width, channels)
+        """
+        with tf.variable_scope("generator",reuse=tf.AUTO_REUSE):
+            layer_gen = self.convLSTM_network(self.x)
+            layer_gen_pred = layer_gen[:,self.context_frames-1:,:,:,:]
+        return layer_gen
+
+
+    def discriminator(self,vid):
+        """
+        Function that builds the discriminator architecture
+        """
+        with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE):
+            conv1 = tf.layers.conv3d(vid,64,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis1")
+            conv1 = ConvLstmGANVideoPredictionModel.lrelu(conv1)
+            conv2 = tf.layers.conv3d(conv1,128,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis2")
+            conv2 = ConvLstmGANVideoPredictionModel.lrelu(self.bd1(conv2))
+            conv3 = tf.layers.conv3d(conv2,256,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis3")
+            conv3 = ConvLstmGANVideoPredictionModel.lrelu(self.bd2(conv3))
+            conv4 = tf.layers.conv3d(conv3,512,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis4")
+            conv4 = ConvLstmGANVideoPredictionModel.lrelu(self.bd3(conv4))
+            conv5 = tf.layers.conv3d(conv4,1,kernel_size=[2,4,4],strides=[1,1,1],padding="SAME",name="dis5")
+            conv5 = tf.reshape(conv5, [-1,1])
+            conv5sigmoid = tf.nn.sigmoid(conv5)
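+            # return both the probability and the raw logits; the losses below use the
+            # logits with sigmoid_cross_entropy_with_logits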
+            return conv5sigmoid,conv5
+
+    def discriminator0(self,image):
+        """
+        Function that builds the discriminator architecture
+        """
+        with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE):
+            layer_disc = self.convLSTM_network(image)
+            layer_disc = layer_disc[:,self.context_frames-1:self.context_frames,:,:, 0:1]
+        return layer_disc
+
+    def discriminator1(self,sequence):
+        """
+        https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py
+        Function that returns the probability that a sequence of frames is real or fake
+        the input sequence shape is [batch_size, time_seq_length, height, width, channel] (e.g., self.x[:,:self.context_frames,:,:,:])
+        """
+        with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE):
+            print(sequence.shape)
+            x = sequence[:,:,:,:,0:1] # extract targeted variable
+            x = tf.transpose(x, [0,2,3,1,4]) # sequence shape is like: [batch_size,height,width,time_seq_length]
+            x = tf.reshape(x,[x.shape[0],x.shape[1],x.shape[2],x.shape[3]])
+            print(x.shape)
+            net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.conv2d(x, 64, 4, 4, 2, 2, name='d_conv1'))
+            net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'),scope='d_bn2'))
+            net = tf.reshape(net, [self.batch_size, -1])
+            net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.linear(net, 1024, scope='d_fc3'),scope='d_bn3'))
+            out_logit = ConvLstmGANVideoPredictionModel.linear(net, 1, scope='d_fc4')
+            out = tf.nn.sigmoid(out_logit)
+            print(out.shape)
+        return out, out_logit
+
+    def get_disc_loss(self):
+        """
+        Return the loss of discriminator given inputs
+        """
+          
+        real_labels = tf.ones_like(self.D_real)
+        gen_labels = tf.zeros_like(self.D_fake)
+        self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real_logits, labels=real_labels))
+        self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=gen_labels))
+        self.D_loss = self.D_loss_real + self.D_loss_fake
+        return self.D_loss
+
+
+    def get_gen_loss(self):
+        """
+        Return the loss of the generator, i.e. how well the generated frames
+        fool the discriminator
+        """
+        real_labels = tf.ones_like(self.D_fake)
+        self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=real_labels))
+        return self.G_loss         
+   
+    def get_vars(self):
+        """
+        Get trainable variables from discriminator and generator
+        """
+        print("trinable_varialbes", len(tf.trainable_variables()))
+        self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
+        self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
+        print("self.disc_vars",self.disc_vars)
+        print("self.gen_vars",self.gen_vars)
+ 
+  
+    def define_gan(self):
+        """
+        Define the GAN architecture
+        """
+        self.noise = self.get_noise()
+        self.gen_images = self.generator()
+        # !!! the discriminator inputs have to be adapted when a different discriminator variant is used
+        self.D_real, self.D_real_logits = self.discriminator(self.x[:,self.context_frames:,:,:,:])
+        self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:,self.context_frames-1:,:,:,:])
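+        # D_real sees the ground-truth future frames, D_fake the generated ones;
+        # gen_images is shifted by one frame because the generator predicts frame t+1 from frame t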
+        self.get_gen_loss()
+        self.get_disc_loss()
+        self.get_vars()
+        if self.loss_fun == "rmse":
+            self.recon_loss = tf.reduce_mean(tf.square(self.x[:, self.context_frames:,:,:,0] - self.gen_images[:,self.context_frames-1:,:,:,0]))
+        elif self.loss_fun == "cross_entropy":
+            x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1])
+            x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:,self.context_frames-1:,:,:,0],[-1])
+            bce = tf.keras.losses.BinaryCrossentropy()
+            self.recon_loss = bce(x_flatten,x_hat_predict_frames_flatten)
+        else:
+            raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'")   
+
+
+    @staticmethod
+    def convLSTM_cell(inputs, hidden):
+        y_0 = inputs # we only use patch 1; the original paper uses patch 4 for the Moving MNIST case and 2 for the Radar Echo dataset
+        channels = inputs.get_shape()[-1]
+        # conv lstm cell
+        cell_shape = y_0.get_shape().as_list()
+        channels = cell_shape[-1]
+        with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
+            cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64)
+            if hidden is None:
+                hidden = cell.zero_state(y_0, tf.float32)
+            output, hidden = cell(y_0, hidden)
+        output_shape = output.get_shape().as_list()
+        z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
+        # we feed the learned representation into a 1x1 convolutional layer to generate the final prediction
+        x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid")
+        print('x_hat shape is: ',x_hat.shape)
+        return x_hat, hidden
+
+    def convLSTM_network(self,x):
+        network_template = tf.make_template('network',VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
+        # create network
+        x_hat = []
+        
+        #This is for training (optimization of convLSTM layer)
+        hidden_g = None
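+        # the cell receives the ground-truth frame for the first context_frames steps
+        # and is fed its own previous prediction afterwards (closed-loop forecasting)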
+        for i in range(self.sequence_length-1):
+            if i < self.context_frames:
+                x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g)
+            else:
+                x_1_g, hidden_g = network_template(x_1_g, hidden_g)
+            x_hat.append(x_1_g)
+
+        # pack them all together
+        x_hat = tf.stack(x_hat)
+        self.x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])  # swap the time and batch dimensions so the output is [batch, time, height, width, channel]
+        print('self.x_hat shape is: ',self.x_hat.shape)
+        return self.x_hat
+     
+   
+   
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0b0d61edcc2464492fbd00e733ff4ce0130c04a
--- /dev/null
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py
@@ -0,0 +1,242 @@
+__email__ = "b.gong@fz-juelich.de"
+__author__ = "Bing Gong"
+__date__ = "2021=01-05"
+
+
+
+"""
+This implementation takes the following as references:
+1) https://stackabuse.com/introduction-to-gans-with-python-and-tensorflow/
+2) Coursera GAN courses
+3) https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py
+"""
+import collections
+import functools
+import itertools
+from collections import OrderedDict
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.util import nest
+from model_modules.video_prediction import ops, flow_ops
+from model_modules.video_prediction.models import BaseVideoPredictionModel
+from model_modules.video_prediction.models import networks
+from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
+from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
+from model_modules.video_prediction.utils import tf_utils
+from datetime import datetime
+from pathlib import Path
+from model_modules.video_prediction.layers import layer_def as ld
+from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+from tensorflow.contrib.training import HParams
+
+class VanillaGANVideoPredictionModel(object):
+    def __init__(self, mode='train', hparams_dict=None):
+        """
+        This is the class for building the vanilla GAN architecture from the updated hyperparameters
+        args:
+             mode        : str, "train" or "val"; the mode is mainly relevant for the GAN-based models
+             hparams_dict: dict, dictionary containing the hyperparameter names and values
+        """
+        self.mode = mode
+        self.hparams_dict = hparams_dict
+        self.hparams = self.parse_hparams()        
+        self.learning_rate = self.hparams.lr
+        self.total_loss = None
+        self.context_frames = self.hparams.context_frames
+        self.sequence_length = self.hparams.sequence_length
+        self.predict_frames = self.sequence_length - self.context_frames
+        self.max_epochs = self.hparams.max_epochs
+        self.loss_fun = self.hparams.loss_fun
+        self.batch_size = self.hparams.batch_size
+        self.z_dim = self.hparams.z_dim #dim of noise-vector
+
+    def get_default_hparams(self):
+        return HParams(**self.get_default_hparams_dict())
+
+    def parse_hparams(self):
+        """
+        Parse the hparams setting to override the default ones
+        """
+        
+        parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
+        return parsed_hparams
+
+
+    def get_default_hparams_dict(self):
+        """
+        The function that contains default hparams
+        Returns:
+            A dict with the following hyperparameters.
+            context_frames  : the number of ground-truth frames to pass in at start.
+            sequence_length : the number of frames in the video sequence 
+            max_epochs      : the number of epochs to train model
+            lr              : learning rate
+            loss_fun        : the loss function
+        """
+        hparams = dict(
+            context_frames=12,
+            sequence_length=24,
+            max_epochs = 20,
+            batch_size = 40,
+            lr = 0.001,
+            loss_fun = "cross_entropy",
+            shuffle_on_val= True,
+            z_dim = 32,
+         )
+        return hparams
+
+
+    def build_graph(self, x):
+        self.is_build_graph = False
+        self.inputs = x
+        self.x = x["images"]
+        self.width = self.x.shape.as_list()[3]
+        self.height = self.x.shape.as_list()[2]
+        self.channels = self.x.shape.as_list()[4]
+        self.n_samples = self.x.shape.as_list()[0] * self.x.shape.as_list()[1]
+        self.x = tf.reshape(self.x, [-1, self.height,self.width,self.channels]) 
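+        # collapse the batch and time axes: the vanilla GAN treats every frame as an
+        # independent image sample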
+        self.global_step = tf.train.get_or_create_global_step()
+        original_global_variables = tf.global_variables()
+        # Architecture
+        self.define_gan()
+        #This is the loss function (RMSE):
+        #This is loss function only for 1 channel (temperature RMSE)
+        if self.mode == "train":
+            self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars)
+            with tf.control_dependencies([self.D_solver]):
+                self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.G_loss, var_list=self.gen_vars)
+            with tf.control_dependencies([self.G_solver]):
+                self.train_op = tf.assign_add(self.global_step,1)
+        else:
+           self.train_op = None 
+        self.total_loss = self.G_loss + self.D_loss 
+        self.outputs = {}
+        self.outputs["gen_images"] = self.gen_images
+        self.outputs["total_loss"] = self.total_loss
+        # Summary op
+        self.loss_summary = tf.summary.scalar("total_loss", self.G_loss + self.D_loss)
+        self.summary_op = tf.summary.merge_all()
+        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        self.saveable_variables = [self.global_step] + global_variables
+        self.is_build_graph = True
+        return self.is_build_graph 
+    
+    def get_noise(self):
+        """
+        Function for creating noise of shape (n_samples, height, width, channels)
+        """ 
+        self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.n_samples, self.height, self.width, self.channels])
+        return self.noise
+
+    def get_generator_block(self,inputs,output_dim,idx):
+       
+        """
+        Generator Block
+        Function that returns one generator block given the input and output dimensions
+        args:
+            inputs    : the input tensor
+            output_dim: the number of output features
+        return:
+             a generator block: a convolutional layer followed by batch normalization and a ReLU activation
+       
+        """
+        output1 = ld.conv_layer(inputs,kernel_size=2,stride=1,num_features=output_dim,idx=idx,activate="linear")
+        output2 = ld.bn_layers(output1,idx,is_training=False)
+        output3 = tf.nn.relu(output2)
+        return output3
+
+
+    def generator(self,hidden_dim):
+        """
+        Function to build up the generator architecture
+        args:
+            noise: a noise tensor with dimension (n_samples,height,width,channel)
+            hidden_dim: the inner dimension
+        """
+        with tf.variable_scope("generator",reuse=tf.AUTO_REUSE):
+            layer1 = self.get_generator_block(self.noise,hidden_dim,1)
+            layer2 = self.get_generator_block(layer1,hidden_dim*2,2)
+            layer3 = self.get_generator_block(layer2,hidden_dim*4,3)
+            layer4 = self.get_generator_block(layer3,hidden_dim*8,4)
+            layer5 = ld.conv_layer(layer4,kernel_size=2,stride=1,num_features=self.channels,idx=5,activate="linear")
+            layer6 = tf.nn.sigmoid(layer5,name="6_conv")
+        print("layer6",layer6)
+        return layer6
+
+
+
+    def get_discriminator_block(self,inputs,output_dim,idx):
+
+        """
+        Discriminator block
+        Function that returns one discriminator block given the input and output dimensions
+
+        args:
+           inputs    : the input tensor
+           output_dim: the number of output features
+           idx       : the index used for the name scope of this block
+        Return:
+           a discriminator block: a convolutional layer followed by a leaky ReLU activation
+        """
+        output1 = ld.conv_layer(inputs,2,stride=1,num_features=output_dim,idx=idx,activate="linear")
+        output2 = tf.nn.leaky_relu(output1)
+        return output2
+
+
+    def discriminator(self,image,hidden_dim):
+        """
+        Function that builds the discriminator architecture
+        """
+        with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE):
+            layer1 = self.get_discriminator_block(image,hidden_dim,idx=1)
+            layer2 = self.get_discriminator_block(layer1,hidden_dim*4,idx=2)
+            layer3 = self.get_discriminator_block(layer2,hidden_dim*2,idx=3)
+            layer4 = self.get_discriminator_block(layer3, self.channels,idx=4)
+            layer5 = tf.nn.sigmoid(layer4)
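+            # the output is a per-pixel sigmoid map; the losses compare it against
+            # all-ones / all-zeros label maps of the same shape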
+        return layer5
+
+
+    def get_disc_loss(self):
+        """
+        Return the loss of discriminator given inputs
+        """
+          
+        real_labels = tf.ones_like(self.D_real)
+        gen_labels = tf.zeros_like(self.D_fake)
+        D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real, labels=real_labels))
+        D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=gen_labels))
+        self.D_loss = D_loss_real + D_loss_fake
+        return self.D_loss
+
+
+    def get_gen_loss(self):
+        """
+        Return the loss of the generator, i.e. how well the generated images
+        fool the discriminator
+        """
+        real_labels = tf.ones_like(self.D_fake)
+        self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=real_labels))
+        return self.G_loss         
+   
+    def get_vars(self):
+        """
+        Get trainable variables from discriminator and generator
+        """
+        self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
+        self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
+       
+ 
+  
+    def define_gan(self):
+        """
+        Define the GAN architecture
+        """
+        self.noise = self.get_noise()
+        self.gen_images = self.generator(hidden_dim=8)
+        self.D_real = self.discriminator(self.x,hidden_dim=8)
+        self.D_fake = self.discriminator(self.gen_images,hidden_dim=8)
+        self.get_gen_loss()
+        self.get_disc_loss()
+        self.get_vars()
+      
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
index 58172bca0401cdc2b2a4353ac2aeee092d59774a..1780b2e8439341320fe5726dab8f7174225a5956 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
@@ -8,6 +8,8 @@ from model_modules.video_prediction.layers import layer_def as ld
 from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 from tensorflow.contrib.training import HParams
 
+
+
 class VanillaConvLstmVideoPredictionModel(object):
     def __init__(self, mode='train', hparams_dict=None):
         """
@@ -65,6 +67,7 @@ class VanillaConvLstmVideoPredictionModel(object):
 
     def build_graph(self, x):
         self.is_build_graph = False
+        self.inputs = x
         self.x = x["images"]
         self.global_step = tf.train.get_or_create_global_step()
         original_global_variables = tf.global_variables()