diff --git a/video_prediction_tools/data_extraction/extract_weatherbench.py b/video_prediction_tools/data_extraction/extract_weatherbench.py
index a5798c188d62b69d93ef22efa0253a80bf2b031a..ca126a8ccd301c9adb00f0cc54f71e3913bc6a0a 100644
--- a/video_prediction_tools/data_extraction/extract_weatherbench.py
+++ b/video_prediction_tools/data_extraction/extract_weatherbench.py
@@ -1,16 +1,12 @@
-import os, glob
-import logging
 
+import logging
 from zipfile import ZipFile
 from typing import Union
 from pathlib import Path
 import multiprocessing as mp
 import itertools as it
-import sys
-
 import pandas as pd
 import xarray as xr
-
 from utils.dataset_utils import get_filename_template
 
 logging.basicConfig(level=logging.DEBUG)
diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 7dad59d4fc70f0abb82dfbe6fa3be466be8c1f03..0ec4095e831b558e0c4bc7f8abf6432f948115d4 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -12,16 +12,16 @@ __email__ = "b.gong@fz-juelich.de"
 __author__ = "Bing Gong, Michael Langguth"
 __date__ = "2020-10-22"
 
-import os, glob
+import os, glob, sys
 import argparse
 import errno
 import json
-from typing import Union, List
 import random
 import time
 import numpy as np
 import xarray as xr
 import tensorflow as tf
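+# make the sibling packages (model_modules, utils, ...) importable when running from main_scripts/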
+sys.path.append("../")
 from model_modules.video_prediction import models
 from model_modules.video_prediction.datasets import get_dataset
 import matplotlib.pyplot as plt
@@ -31,27 +31,39 @@ from general_utils import *
 import math
 import shutil
 from pathlib import Path
+from tensorflow.python.keras.utils.layer_utils import count_params
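+# run tf.functions eagerly to ease debugging; deprecated alias of tf.config.run_functions_eagerly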
+tf.config.experimental_run_functions_eagerly(True)
+
 
 class TrainModel(object):
-    def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None,
-                 model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset_name: str = None,
-                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.001,
-                 frac_start_save: float = None, frac_intv_save: float = None):
+    def __init__(self, input_dir: str = None,
+                 output_dir: str = None,
+                 datasplit_dict: str = None,
+                 model_hparams_dict: str = None,
+                 model: str = None,
+                 checkpoint: str = None,
+                 dataset_name: str = None,
+                 gpu_mem_frac: float = 1.,
+                 seed: int = None,
+                 args=None,
+                 diag_intv_frac: float = 0.001,
+                 frac_start_save: float = None,
+                 frac_intv_save: float = None):
         """
         Class instance for training the models
-        :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located
-        :param output_dir: directory where all the output is saved (e.g. model, JSON-files, training curves etc.)
-        :param datasplit_dict: JSON-file for defining data splitting
-        :param model_hparams_dict: JSON-file of model hyperparameters
-        :param model: model class name
-        :param checkpoint: checkpoint directory (pre-trained models)
-        :param dataset: dataset class name
-        :param gpu_mem_frac: fraction of GPU memory to be preallocated
-        :param seed: seed of the randomizers
-        :param args: list of arguments passed
-        :param diag_intv_frac: interval for diagnozing the model (create loss-curves and save pickle-file with losses)
-        :param frac_start_save: fraction of total iterations steps to start checkpointing the model
-        :param frac_intv_save: fraction of total iterations steps for checkpointing the model
+        :param input_dir          : parent directory under which the "pickle" and "tfrecords" directories are located
+        :param output_dir         : directory where all the output is saved (e.g. model, JSON-files, training curves etc.)
+        :param datasplit_dict     : JSON-file for defining data splitting
+        :param model_hparams_dict : JSON-file of model hyperparameters
+        :param model              : model class name
+        :param checkpoint         : checkpoint directory (pre-trained models)
+        :param dataset_name       : dataset class name
+        :param gpu_mem_frac       : fraction of GPU memory to be preallocated
+        :param seed               : seed of the randomizers
+        :param args               : list of arguments passed
+        :param diag_intv_frac     : interval for diagnosing the model (create loss-curves and save pickle-file with losses)
+        :param frac_start_save    : fraction of total iteration steps after which checkpointing of the model starts
+        :param frac_intv_save     : fraction of total iteration steps between two model checkpoints
         """
         self.input_dir = Path(input_dir).resolve(strict=False)
         self.output_dir = Path(output_dir).resolve(strict=False)
@@ -66,7 +78,7 @@ class TrainModel(object):
         self.diag_intv_frac = diag_intv_frac
         self.frac_start_save = frac_start_save
         self.frac_intv_save = frac_intv_save
-        # for diagnozing and saving the model during training
+
         self.saver_loss = None         # set in create_fetches_for_train-method
         self.saver_loss_name = None    # set in create_fetches_for_train-method 
         self.saver_loss_dict = None    # set in create_fetches_for_train-method if loss of interest is nested 
@@ -77,21 +89,20 @@ class TrainModel(object):
         self.get_model_hparams_dict()
         self.load_params_from_checkpoints_dir()
         self.setup_datasets()
-        self.make_dataset_iterator()
         self.setup_model()
-        self.setup_graph()
-        self.save_dataset_model_params_to_checkpoint_dir(dataset=self.dataset, video_model=self.video_model) # TODO: resolve potetial incompatibility
-        self.count_parameters()
-        self.create_saver_and_writer()
-        self.setup_gpu_config()
-        self.calculate_checkpoint_saver_conf()
+        self.save_dataset_model_params_to_checkpoint_dir(dataset=self.dataset,
+                                                         video_model=self.video_model)
+        # self.count_parameters()
+        # self.create_saver_and_writer()
+        # self.setup_gpu_config()
+        # self.calculate_checkpoint_saver_conf()
 
     def set_seed(self):
         """
         Set seed to control the same train/val/testing dataset for the same seed
         """
         if self.seed is not None:
-            tf.set_random_seed(self.seed)
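+            # TF2 renamed tf.set_random_seed to tf.random.set_seed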
+            tf.random.set_seed(self.seed)
             np.random.seed(self.seed)
             random.seed(self.seed)
 
@@ -113,10 +124,10 @@ class TrainModel(object):
         """
         if self.model_hparams_dict:
             with open(self.model_hparams_dict, 'r') as f:
-                print("self.model_hparams_dict",self.model_hparams_dict)
                 self.model_hparams_dict_load = json.loads(f.read())
         else:
-            raise FileNotFoundError("hparam directory doesn't exist! please check {}!".format(self.model_hparams_dict))
+            raise FileNotFoundError("hparam directory doesn't exist! "
+                                    "please check {}!".format(self.model_hparams_dict))
 
         return self.model_hparams_dict_load
 
@@ -124,7 +135,8 @@ class TrainModel(object):
         """
         If checkpoint is none, load and read the json files of datasplit_config, and hparam_config,
         and use the corresponding parameters.
-        If the checkpoint is given, the configuration of dataset, model and options in the checkpoint dir will be
+        If the checkpoint is given, the configuration of dataset,
+        model and options in the checkpoint dir will be
         restored and used for continue training.
         """
         method = TrainModel.load_params_from_checkpoints_dir.__name__
@@ -144,7 +156,7 @@ class TrainModel(object):
                     self.model = self.model or self.options['model']
             except FileNotFoundError:
                 print("%{0}: options.json does not exist in {1}".format(method, self.checkpoint_dir))
-            # loading hyperparameters from checkpoint
+
             try:
                 with open(os.path.join(self.checkpoint_dir, "model_hparams.json")) as f:
                     self.model_hparams_dict_load.update(json.loads(f.read()))
@@ -154,15 +166,21 @@ class TrainModel(object):
     def setup_datasets(self):
         """
         Setup train and val dataset instance with the corresponding data split configuration.
-        Simultaneously, sequence_length is attached to the hyperparameter dictionary.
+        Simultaneously, sequence_length is attached to the hyper-parameter dictionary.
         """
-        # get some parameters from the model hyperparameters
+        # get some parameters from the model hyper-parameters
         self.batch_size = self.model_hparams_dict_load["batch_size"]
         self.max_epochs = self.model_hparams_dict_load["max_epochs"]
         # create dataset instance
-
-        self.dataset = get_dataset(self.dataset_name, input_dir=self.input_dir, output_dir=self.output_dir, datasplit_path=self.datasplit_dict, hparams_path=self.model_hparams_dict, seed=self.seed)
-        
+        self.dataset = get_dataset(self.dataset_name,
+                                   input_dir=self.input_dir,
+                                   output_dir=self.output_dir,
+                                   datasplit_path=self.datasplit_dict,
+                                   hparams_path=self.model_hparams_dict,
+                                   seed=self.seed)
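+        # build the tf.data input pipelines for all three splits up front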
+        self.train_dataset = self.dataset.make_dataset(mode="train")
+        self.val_dataset = self.dataset.make_dataset(mode="val")
+        self.test_dataset = self.dataset.make_dataset(mode="test")
         self.calculate_samples_and_epochs()
         self.model_hparams_dict_load.update({"sequence_length": self.dataset.sequence_length})
 
@@ -172,34 +190,11 @@ class TrainModel(object):
         :param mode: "train" used the model graph in train process;  "test" for postprocessing step
         """
         VideoPredictionModel = models.get_model_class(self.model)
-        self.video_model = VideoPredictionModel(hparams_dict_config=self.model_hparams_dict, mode=mode)
+        self.video_model = VideoPredictionModel(hparams_dict_config=self.model_hparams_dict,
+                                                mode=mode)
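+        # build the model with a hard-coded input shape; the meaning of [24, 10, 21, 2] is assumed here (e.g. [seq_len, lat, lon, channels])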
+        self.video_model.compile([24, 10, 21, 2])
 
-    def setup_graph(self):
-        """
-        build model graph
-        """
-        self.video_model.build_graph(self.inputs)
-        
-    def make_dataset_iterator(self):
-        """
-        Prepare the dataset interator for training and validation
-        """
-        self.batch_size = self.model_hparams_dict_load["batch_size"]
-        train_tf_dataset = self.dataset.make_training()
-        train_iterator = train_tf_dataset.make_one_shot_iterator()
-        # The `Iterator.string_handle()` method returns a tensor that can be evaluated
-        # and used to feed the `handle` placeholder.
-        self.train_handle = train_iterator.string_handle()
-        val_tf_dataset = self.dataset.make_validation()
-        val_iterator = val_tf_dataset.make_one_shot_iterator()
-        self.val_handle = val_iterator.string_handle()
-        self.iterator = tf.data.Iterator.from_string_handle(
-            self.train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
-        self.inputs = self.iterator.get_next()
-        # since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train SAVP
-        # Otherwise an error will be risen by SAVP 
-        if self.dataset_name == "era5" and self.model == "savp":
-            del self.inputs["T_start"]
 
     def save_dataset_model_params_to_checkpoint_dir(self, dataset, video_model):
         """
@@ -210,40 +205,40 @@ class TrainModel(object):
         with open(os.path.join(self.output_dir, "dataset_hparams.json"), "w") as f:
             f.write(json.dumps(dataset.hparams, sort_keys=True, indent=4))
         with open(os.path.join(self.output_dir, "model_hparams.json"), "w") as f:
-            print("video_model.get_hparams",video_model.get_hparams)
             f.write(json.dumps(video_model.get_hparams, sort_keys=True, indent=4))
-        #with open(os.path.join(self.output_dir, "data_dict.json"), "w") as f:
-        #   f.write(json.dumps(dataset.data_dict, sort_keys=True, indent=4))
+
 
     def count_parameters(self):
         """
         Count the paramteres of the model
-        """ 
-        with tf.name_scope("parameter_count"):
-            # exclude trainable variables that are replicates (used in multi-gpu setting)
-            self.trainable_variables = set(tf.trainable_variables()) & set(self.video_model.saveable_variables)
-            self.parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in self.trainable_variables])
+        """
+        self.trainable_variables = count_params(self.video_model.model.trainable_weights)
+        self.non_trainable_count = count_params(self.video_model.model.non_trainable_weights)
+        self.parameter_count = self.trainable_variables + self.non_trainable_count
 
     def create_saver_and_writer(self):
         """
         Create saver to save the models latest checkpoints, and a summery writer to store the train/val metrics  
         """
-        self.saver = tf.train.Saver(var_list=self.video_model.saveable_variables, max_to_keep=None)
+        self.saver = tf.train.Saver(var_list=self.video_model.saveable_variables,
+                                    max_to_keep=None)
         self.summary_writer = tf.summary.FileWriter(self.output_dir)
 
     def setup_gpu_config(self):
         """
         Setup GPU options 
         """
-        self.gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.gpu_mem_frac, allow_growth=True)
-        self.config = tf.ConfigProto(gpu_options=self.gpu_options, allow_soft_placement=True)
+        self.gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.gpu_mem_frac,
+                                         allow_growth=True)
+        self.config = tf.ConfigProto(gpu_options=self.gpu_options,
+                                     allow_soft_placement=True)
 
     def calculate_samples_and_epochs(self):
         """
         Calculate the number of samples for train dataset, which is used for each epoch training
         Calculate the iterations (samples multiple by max_epochs) for training.
         """
-        method = TrainModel.calculate_samples_and_epochs.__name__        
+        method = TrainModel.calculate_samples_and_epochs.__name__
 
         self.num_examples = self.dataset.num_training_samples
         self.steps_per_epoch = int(self.num_examples/self.batch_size)
@@ -255,11 +250,14 @@ class TrainModel(object):
         else:
             pass
         print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}"
-              .format(method, self.batch_size, self.max_epochs, self.num_examples, self.steps_per_epoch,
+              .format(method,
+                      self.batch_size,
+                      self.max_epochs,
+                      self.num_examples,
+                      self.steps_per_epoch,
                       self.total_steps))
 
 
-
     def calculate_checkpoint_saver_conf(self):
         """
         Calculate the start step for saving the checkpoint, and the frequences steps to save model
@@ -268,17 +266,20 @@ class TrainModel(object):
         method = TrainModel.calculate_checkpoint_saver_conf.__name__
 
         if not hasattr(self, "total_steps"):
-            raise RuntimeError("%{0} self.total_steps is still unset. Run calculate_samples_and_epochs beforehand"
+            raise RuntimeError("%{0} self.total_steps is still unset. "
+                               "Run calculate_samples_and_epochs beforehand"
                                .format(method))
-        if self.frac_intv_save > 1 or self.frac_intv_save<0 :
-            raise ValueError("%{0}: frac_intv_save must be less than 1 and larger than 0".format(method))
+        if self.frac_intv_save > 1 or self.frac_intv_save < 0:
+            raise ValueError("%{0}: frac_intv_save must be less than 1 "
+                             "and larger than 0".format(method))
         if self.frac_start_save > 1 or self.frac_start_save < 0:
-            raise ValueError("%{0}: frac_start_save must be less than 1 and larger than 0".format(method))
+            raise ValueError("%{0}: frac_start_save must be less than 1 "
+                             "and larger than 0".format(method))
 
         self.chp_start_step = int(math.ceil(self.total_steps * self.frac_start_save))
         self.chp_intv_step = int(math.ceil(self.total_steps * self.frac_intv_save))
         print("%{0}: Model will be saved after step {1:d} at each {2:d} interval step "
-              .format(method, self.chp_start_step,self.chp_intv_step))
+              .format(method, self.chp_start_step, self.chp_intv_step))
 
     def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
         """
@@ -292,11 +293,10 @@ class TrainModel(object):
             print("%{0}: There are no checkpoints in the dir {1}".format(method, checkpoints))
         else:
             var_list = self.video_model.saveable_variables
-            # possibly restore from multiple checkpoints. useful if subset of weights
-            # (e.g. generator or discriminator) are on different checkpoints.
+
             if not isinstance(checkpoints, (list, tuple)):
                 checkpoints = [checkpoints]
-            # automatically skip global_step if more than one checkpoint is provided
+
             skip_global_step = len(checkpoints) > 1
             savers = []
             for checkpoint in checkpoints:
@@ -314,14 +314,15 @@ class TrainModel(object):
         """
         if self.checkpoint is None:
             train_losses, val_losses = [], []
-        elif os.path.isdir(self.checkpoint) and (not os.path.exists(os.path.join(self.output_dir, "checkpoint"))):
-            train_losses,val_losses = [], []
+        elif os.path.isdir(self.checkpoint) and (not os.path.exists(os.path.join(self.output_dir,
+                                                                                 "checkpoint"))):
+            train_losses, val_losses = [], []
         else:
             with open(os.path.join(self.output_dir, "train_losses.pkl"), "rb") as f:
                 train_losses = pkl.load(f)
             with open(os.path.join(self.output_dir, "val_losses.pkl"), "rb") as f:
                 val_losses = pkl.load(f)
-        return train_losses,val_losses
+        return train_losses, val_losses
 
     def create_checkpoints_folder(self, step:int=None):
         """
@@ -337,96 +338,17 @@ class TrainModel(object):
         """
         Start session and train the model by looping over all iteration steps
         """
-        method = TrainModel.train_model.__name__
-
-        self.global_step = tf.train.get_or_create_global_step()
-        with tf.Session(config=self.config) as sess:
-            sess.run(tf.global_variables_initializer())
-            sess.run(tf.local_variables_initializer())
-            self.restore(sess, self.checkpoint)
-            start_step = sess.run(self.global_step)
-            print("%{0}: Iteration starts at step {1}".format(method, start_step))
-            # start at one step earlier to log everything without doing any training
-            # step is relative to the start_step
-            train_losses, val_losses = self.restore_train_val_losses()
-            # initialize auxiliary variables
-            time_per_iteration = []
-            run_start_time = time.time()
-            # perform iteration
-            for step in range(start_step, self.total_steps):
-                timeit_start = time.time()
-                # Run training data
-                self.create_fetches_for_train()             # In addition to the loss, we fetch the optimizer
-                self.results = sess.run(self.fetches)       # ...and run it here!
-                # Note: For SAVP, the obtained loss is a list where the first element is of interest, for convLSTM,
-                # it's just a number. Thus, with ensure_list(<losses>)[0], we can handle both
-                train_losses.append(ensure_list(self.results[self.saver_loss])[0])
-                # run and fetch losses for validation data
-                val_handle_eval = sess.run(self.val_handle)
-                self.create_fetches_for_val()
-                self.val_results = sess.run(self.val_fetches, feed_dict={self.train_handle: val_handle_eval})
-                val_losses.append(ensure_list(self.val_results[self.saver_loss])[0])
-                self.write_to_summary()
-                self.print_results(step, self.results)
-                # track iteration time
-                time_iter = time.time() - timeit_start
-                time_per_iteration.append(time_iter)
-                print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter))
-                if (step >= self.chp_start_step and (step-self.chp_start_step)%self.chp_intv_step == 0) or \
-                    step == self.total_steps - 1:
-                    #create a checkpoint folder for step
-                    full_dir_name = self.create_checkpoints_folder(step=step)
-                    self.saver.save(sess, os.path.join(full_dir_name, "model_"), global_step=step)
-
-                # pickle file and plots are always created
-                if step % self.diag_intv_step == 0 or step == self.total_steps - 1:
-                    TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
-                    TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
-
-            # Final diagnostics: training track time and save to pickle-files)
-            train_time = time.time() - run_start_time
-
-            avg_time_first_epoch = np.mean(time_per_iteration[:self.steps_per_epoch])
-            avg_time_non_first_epoch = np.mean(time_per_iteration[self.steps_per_epoch:])
-            results_dict = {"train_time": train_time, "total_steps": self.total_steps,
-                            "avg_time_first_epoch": avg_time_first_epoch,
-                            "avg_time_non_first_epoch":avg_time_non_first_epoch}
-
-            TrainModel.save_results_to_dict(results_dict, self.output_dir)
-
-            print("%{0}: Training loss decreased from {1:.6f} to {2:.6f}:"
-                  .format(method, np.mean(train_losses[0:10]), np.mean(train_losses[-self.diag_intv_step:])))
-            print("%{0}: Validation loss decreased from {1:.6f} to {2:.6f}:"
-                  .format(method, np.mean(val_losses[0:10]), np.mean(val_losses[-self.diag_intv_step:])))
-            print("%{0}: Training finished".format(method))
-            print("%{0}: Total training time: {1:.2f} min".format(method, train_time/60.))
-            print("%{0}: The average of training time for the first epoch: {1:.2f} sec".format(method, avg_time_first_epoch))
-            print("%{0}: The average of training time for after first epoch: {1:.2f} sec".format(method,avg_time_non_first_epoch))
-            return train_time, time_per_iteration
- 
-    def create_fetches_for_train(self):
-        """
-        Fetch variables in the graph, this can be custermized based on models and also the needs of users
-        """
-        # This is the basic fetch for all the models
-        fetch_list = ["train_op", "summary_op", "global_step","total_loss"]
+        global_step = 1
+        # initialize auxiliary variables
+        run_start_time = time.time()
+        for epoch in range(self.max_epochs):
+            for step, inputs in enumerate(self.train_dataset):
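+                # one optimization step per batch; runs eagerly due to the flag set at import time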
+                self.video_model.train_step(inputs, global_step)
+                print("loss", self.video_model.loss_value)
+                global_step += 1
 
-        # Append fetches depending on model to be trained
-        if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":
-            self.saver_loss = fetch_list[-1]  
-            self.saver_loss_name = "Loss"
-        if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
-            fetch_list = fetch_list + ["inputs"]
-            self.saver_loss = fetch_list[-1]
-            self.saver_loss_name = "Total loss"
-        if self.video_model.__class__.__name__ == "WeatherBenchModel":
-            fetch_list = fetch_list + ["total_loss"]
-            self.saver_loss = fetch_list[-1]
-            self.saver_loss_name = "Total loss"
-        else: 
-            self.saver_loss = "total_loss"
-        self.fetches = self.generate_fetches(fetch_list)
-        return self.fetches
 
     def create_fetches_for_val(self):
         """
@@ -492,6 +414,7 @@ class TrainModel(object):
         else:
             print("Total_loss:{}"
                   .format(results["total_loss"]))
+
     @staticmethod
     def plot_train(train_losses, val_losses, loss_name, output_dir):
         """
@@ -522,10 +445,10 @@ class TrainModel(object):
             json.dump(results_dict,fp) 
 
     @staticmethod
-    def save_results_to_pkl(train_losses,val_losses, output_dir):
+    def save_results_to_pkl(train_losses, val_losses, output_dir):
          with open(os.path.join(output_dir,"train_losses.pkl"),"wb") as f:
             pkl.dump(train_losses,f)
-         with open(os.path.join(output_dir,"val_losses.pkl"),"wb") as f:
+         with open(os.path.join(output_dir, "val_losses.pkl"), "wb") as f:
             pkl.dump(val_losses,f) 
 
     @staticmethod
@@ -587,16 +510,22 @@ class BestModelSelector(object):
         :return: Populated self.checkpoints_eval_all where the average of the metric over all forecast hours is listed
 
         """
-        method = BestModelSelector.run.__name__
         from main_visualize_postprocess import Postprocess
         metric_avg_all = []
 
         for checkpoint in self.checkpoints_all:
             print("Start to evalute checkpoint:", checkpoint)
             results_dir_eager = os.path.join(checkpoint, "results_eager")
-            eager_eval = Postprocess(results_dir=results_dir_eager, checkpoint=checkpoint, data_mode="val", batch_size=32,
-                                     seed=self.seed, eval_metrics=[eval_metric], channel=self.channel, frac_data=0.33,
-                                     lquick=True, ltest=self.ltest)
+            eager_eval = Postprocess(results_dir=results_dir_eager,
+                                     checkpoint=checkpoint,
+                                     data_mode="val",
+                                     batch_size=32,
+                                     seed=self.seed,
+                                     eval_metrics=[eval_metric],
+                                     channel=self.channel,
+                                     frac_data=0.33,
+                                     lquick=True,
+                                     ltest=self.ltest)
             eager_eval.run()
             eager_eval.handle_eval_metrics()
 
@@ -680,7 +609,6 @@ class BestModelSelector(object):
         if ncheckpoints == 0:
             raise FileExistsError("{0}: No checkpoint folders found under '{1}'".format(method, model_dir))
         else:
-            # glob.glob yiels unsorted directories, i.e. do the soring now
             checkpoints_all = sorted(checkpoints_all, key=lambda x: int(x.split("_")[-1].replace("/","")))
             print("%{0}: {1:d} checkpoints directories has been found.".format(method, ncheckpoints))
 
@@ -706,12 +634,12 @@ class BestModelSelector(object):
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type=str, required=True,
+    parser.add_argument("--input_dir",  type=str, required=True,
                         help="Directory where input data as TFRecord-files are stored.")
     parser.add_argument("--output_dir", help="Output directory where JSON-files, summary, model, plots etc. are saved.")
     parser.add_argument("--datasplit_dict", help="JSON-file that contains the datasplit configuration")
     parser.add_argument("--checkpoint", help="Checkpoint directory or checkpoint name (e.g. <my_dir>/model-200000)")
-    parser.add_argument("--dataset", type=str, help="Dataset name") # as in dataset_utils.DATASETS
+    parser.add_argument("--dataset", type=str, help="Dataset name")
     parser.add_argument("--model", type=str, help="Model class name")
     parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters")
     parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use")
@@ -724,18 +652,23 @@ def main():
                         help="Test mode for postprocessing to allow bootstrapping on small datasets.")
 
     args = parser.parse_args()
-    # start timing for the whole run
 
-    # list pip environment
     import os
     print(os.system("pip3 list"))
 
     timeit_start = time.time()
     # create a training instance
-    train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict,
-                 model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset_name=args.dataset,
-                 gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_start_save=args.frac_start_save,
-                 frac_intv_save=args.frac_intv_save)
+    train_case = TrainModel(input_dir=args.input_dir,
+                            output_dir=args.output_dir,
+                            datasplit_dict=args.datasplit_dict,
+                            model_hparams_dict=args.model_hparams_dict,
+                            model=args.model,
+                            checkpoint=args.checkpoint,
+                            dataset_name=args.dataset,
+                            gpu_mem_frac=args.gpu_mem_frac,
+                            seed=args.seed,
+                            args=args,
+                            frac_start_save=args.frac_start_save,
+                            frac_intv_save=args.frac_intv_save)
     
     print('----------------------------------- Options ------------------------------------')
     for k, v in args._get_kwargs():
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index fc55de930ca70eb8624ba85bbe9462f7735a59c9..34ed8da4970c0f30b0357e8d72ca128e601c009d 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -729,6 +729,7 @@ class Postprocess(TrainModel):
             gen_images_denorm = self.denorm_images_all_channels(
                 gen_images, self.vars_in, self.norm_cls, norm_method="minmax"
             )
+            # store data into dataset & get number of samples (may differ from batch_size at the end of the test dataset)
             times_0, init_times = self.get_init_time(t_starts)
             batch_ds = self.create_dataset(
                 input_images_denorm, gen_images_denorm, init_times
@@ -814,7 +815,7 @@ class Postprocess(TrainModel):
         self.sess.run(tf.global_variables_initializer())
         self.sess.run(tf.local_variables_initializer())
 
-    def get_input_data_per_batch(self, input_iter, norm_method="cbnorm"):
+    def get_input_data_per_batch(self, input_iter, norm_method="minmax"):
         """
         Get the input sequence from the dataset iterator object stored in self.inputs and denormalize the data
         :param input_iter: the iterator object built by make_test_dataset_iterator-method
@@ -1297,7 +1298,8 @@ class Postprocess(TrainModel):
 
     @staticmethod
     def denorm_images_all_channels(
-        image_sequence, varnames, norm, norm_method="minmax")
+        image_sequence, varnames, norm, norm_method="minmax"
+    ):
         """
         Denormalize data of all image channels
         :param image_sequence: list/array [batch, seq, lat, lon, channel] of images
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py b/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py
index c2fdd4f1b20751768ce76c138124a215e45f0233..b758eca0efaa5740dcd6f852af4c36b8656962ab 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py
@@ -283,120 +283,13 @@ class Postprocess(TrainModel):
 
     # the run-factory
     def run(self):
-        if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet':
+        if self.model == "convLSTM" or self.model == "test_model":
             self.run_deterministic()
         elif self.run_mode == "deterministic":
             self.run_deterministic()
         else:
             self.run_stochastic()
 
-    def run_stochastic(self):
-        """
-        Run session, save results to netcdf, plot input images, generate images and persistent images
-        """
-        method = Postprocess.run_stochastic.__name__
-        raise ValueError("ML: %{0} is not runnable now".format(method))
-
-        self.init_session()
-        self.restore(self.sess, self.checkpoint)
-        # Loop for samples
-        self.sample_ind = 0
-        self.prst_metric_all = []  # store evaluation metrics of persistence forecast (shape [future_len])
-        self.fcst_metric_all = []  # store evaluation metric of stochastic forecasts (shape [nstoch, batch, future_len])
-        while self.sample_ind < self.num_samples_per_epoch:
-            if self.num_samples_per_epoch < self.sample_ind:
-                break
-            else:
-                # run the inputs and plot each sequence images
-                self.input_results, self.input_images_denorm_all, self.t_starts = self.get_input_data_per_batch()
-
-            feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()}
-            gen_loss_stochastic_batch = []  # [stochastic_ind,future_length]
-            gen_images_stochastic = []  # [stochastic_ind,batch_size,seq_len,lat,lon,channels]
-            # Loop for stochastics
-            for stochastic_sample_ind in range(self.num_stochastic_samples):
-                print("stochastic_sample_ind:", stochastic_sample_ind)
-                # return [batchsize,seq_len,lat,lon,channel]
-                gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)
-                # The generate images seq_len should be sequence_len -1, since the last one is
-                # not used for comparing with groud truth
-                assert gen_images.shape[1] == self.sequence_length - 1
-                gen_images_per_batch = []
-                if stochastic_sample_ind == 0:
-                    persistent_images_per_batch = []  # [batch_size,seq_len,lat,lon,channel]
-                    ts_batch = []
-                for i in range(self.batch_size):
-                    # generate time stamps for sequences only once, since they are the same for all ensemble members
-                    if stochastic_sample_ind == 0:
-                        self.ts = Postprocess.generate_seq_timestamps(self.t_starts[i], len_seq=self.sequence_length)
-                        init_date_str = self.ts[0].strftime("%Y%m%d%H")
-                        ts_batch.append(init_date_str)
-                        # get persistence_images
-                        self.persistence_images, self.ts_persistence = Postprocess.get_persistence(self.ts,
-                                                                                                   self.input_dir_pkl)
-                        persistent_images_per_batch.append(self.persistence_images)
-                        assert len(np.array(persistent_images_per_batch).shape) == 5
-                        self.plot_persistence_images()
-
-                    # Denormalized data for generate
-                    gen_images_ = gen_images[i]
-                    self.gen_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl, gen_images_,
-                                                                                    self.vars_in)
-                    gen_images_per_batch.append(self.gen_images_denorm)
-                    assert len(np.array(gen_images_per_batch).shape) == 5
-                    # only plot when the first stochastic ind otherwise too many plots would be created
-                    # only plot the stochastic results of user-defined ind
-                    self.plot_generate_images(stochastic_sample_ind, self.stochastic_plot_id)
-                # calculate the persistnet error per batch
-                if stochastic_sample_ind == 0:
-                    persistent_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all,
-                                                                                       persistent_images_per_batch,
-                                                                                       self.future_length,
-                                                                                       self.context_frames,
-                                                                                       matric="mse", channel=0)
-                    self.prst_metric_all.append(persistent_loss_per_batch)
-
-                # calculate the gen_images_per_batch error
-                gen_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all,
-                                                                            gen_images_per_batch, self.future_length,
-                                                                            self.context_frames,
-                                                                            matric="mse", channel=0)
-                gen_loss_stochastic_batch.append(
-                    gen_loss_per_batch)  # self.gen_images_stochastic[stochastic,future_length]
-                print("gen_images_per_batch shape:", np.array(gen_images_per_batch).shape)
-                gen_images_stochastic.append(
-                    gen_images_per_batch)  # [stochastic,batch_size, seq_len, lat, lon, channel]
-
-                # Switch the 0 and 1 position
-                print("before transpose:", np.array(gen_images_stochastic).shape)
-            gen_images_stochastic = np.transpose(np.array(gen_images_stochastic), (
-                1, 0, 2, 3, 4, 5))  # [batch_size, stochastic, seq_len, lat, lon, chanel]
-            Postprocess.check_gen_images_stochastic_shape(gen_images_stochastic)
-            assert len(gen_images_stochastic.shape) == 6
-            assert np.array(gen_images_stochastic).shape[1] == self.num_stochastic_samples
-
-            self.fcst_metric_all.append(
-                gen_loss_stochastic_batch)  # [samples/batch_size,stochastic,future_length]
-            # save input and stochastic generate images to netcdf file
-            # For each prediction (either deterministic or ensemble) we create one netCDF file.
-            for batch_id in range(self.batch_size):
-                self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id],
-                                                                   persistent_images_per_batch[batch_id],
-                                                                   np.array(gen_images_stochastic)[batch_id],
-                                                                   fl_name="vfp_date_{}_sample_ind_{}.nc"
-                                                                   .format(ts_batch[batch_id],
-                                                                           self.sample_ind + batch_id))
-
-            self.sample_ind += self.batch_size
-
-        self.persistent_loss_all_batches = np.mean(np.array(self.persistent_loss_all_batches), axis=0)
-        self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0)
-        assert len(np.array(self.persistent_loss_all_batches).shape) == 1
-        assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length
-
-        assert len(np.array(self.stochastic_loss_all_batches).shape) == 2
-        assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples
-
     def run_deterministic(self):
         """
         Revised and vectorized version of run_deterministic
diff --git a/video_prediction_tools/model_modules/video_prediction/__init__.py b/video_prediction_tools/model_modules/video_prediction/__init__.py
index 3089b251bbbcbe086a03a4ea63571ce9427b2cb9..4ad018e9ea2d38ee103b49a1d2c847dfd2faa691 100644
--- a/video_prediction_tools/model_modules/video_prediction/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/__init__.py
@@ -1,3 +1,2 @@
 from . import losses
 from . import metrics
-from . import ops
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
index 3b7afcc929e4affcf6f7aa14da37808d1e1faf78..c789f64f8a69a9a5734b638eec43ec2bd250aba1 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
@@ -1,7 +1,6 @@
 #from .base_dataset import BaseVideoDataset
 #from .era5_dataset import ERA5Dataset
 #from .gzprcp_dataset import GzprcpDataset
-#from .moving_mnist import MovingMnist
 #from data_preprocess.dataset_options import known_datasets
 from .stats import MinMax, ZScore
 from .dataset import Dataset
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py
index b42702a03842ba53ac3b0bcf8db14b7f83b98bc1..767e82bdd9d91bbdd883d5553e8d53906816a8b6 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py
@@ -3,16 +3,14 @@ __date__ = "2022-03-17"
 __email__ = "b.gong@fz-juelich.de"
 
 import json
-import os
-from typing import List
-from dataclasses import dataclass
 from pathlib import Path
-
 import xarray as xr
 import tensorflow as tf
-
+import sys, os
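+# extend sys.path so that hparams_utils (under utils/) and the package root resolve when this file is run directly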
+sys.path.append("../../../utils")
 from hparams_utils import *
-from model_modules.video_prediction.datasets.stats import DatasetStats, Normalize
+sys.path.append("../../../")
+from model_modules.video_prediction.datasets.stats import DatasetStats, Normalize, ZScore
 
 
 class Dataset:
@@ -40,16 +38,16 @@ class Dataset:
     ):
         """
         This class is used for preparing data for training/validation and test models
-        :param input_dir: the path of tfrecords files
+        :param input_dir     : the path of tfrecords files
         :param datasplit_path: the path pointing to the datasplit_config json file
-        :param hparams_path: the path to the dict that contains hparameters,
-        :param mode: string, "train","val" or "test"
-        :param seed: int, the seed for shuffeling the dataset
-        :param nsamples_ref: number of reference samples which can be used to control repetition factor for dataset
+        :param hparams_path  : the path to the dict that contains the hyperparameters
+        :param mode          : string, "train","val" or "test"
+        :param seed          : int, the seed for shuffling the dataset
+        :param nsamples_ref  : number of reference samples which can be used to control repetition factor for dataset
                              for ensuring adopted size of dataset iterator (used for validation data during training)
                              Example: Let nsamples_ref be 1000 while the current datset consists 100 samples, then
                                       the repetition-factor will be 10 (i.e. nsamples*rep_fac = nsamples_ref)
-        :param normalize: class of the desired normalization method
+        :param normalize     : class of the desired normalization method
         """
         self.input_dir = input_dir
         self.output_dir = output_dir
@@ -105,7 +103,7 @@ class Dataset:
         # {"2008":[1,2,3,4,...], "2009":[1,2,3,4,...]}
         for year, months in time_window.items():
             for month in months:
-                files.append(self.input_dir / self.filename_template.format(year=year, month=month))
+                files.append(os.path.join(self.input_dir, self.filename_template.format(year=year, month=month)))
         
         return files
 
@@ -116,6 +114,7 @@ class Dataset:
         :param mode: indicator to differentiate between training, validation and test data
         """
         files = self.filenames(mode)
+
         if not len(files) > 0:
             raise Exception(
                 f"no files for dataset {mode} found, check data_split dictionary"
@@ -167,7 +166,7 @@ class Dataset:
         # shuffle
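+        # note: shuffle_and_repeat is deprecated in TF2; dataset.shuffle(buffer_size, seed).repeat(count) is the recommended replacement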
         if shuffle:
             dataset = dataset.apply(
-                tf.contrib.data.shuffle_and_repeat(
+                tf.data.experimental.shuffle_and_repeat(
                     buffer_size=1024, count=self.max_epochs, seed=self.seed
                 )
             )  # TODO: check, self.seed
@@ -185,7 +184,7 @@ class Dataset:
             stats = self._stats_lookup[mode]
 
         normalize = self.normalize(stats)
-        stats.to_json(self.output_dir / "normalization_stats.json")
+        stats.to_json(os.path.join(self.output_dir, "normalization_stats.json"))
         
         dataset = dataset.map(normalize.normalize_vars)
 
@@ -208,7 +207,9 @@ class Dataset:
         """
         stats = self._get_stats("test")
         return int((stats.n + self.shift) / (self.sequence_length + self.shift))
-    
+
     @property
     def num_validation_samples(self):
         """
@@ -243,3 +244,17 @@ class Dataset:
     @property
     def test_stats(self):
         return self._get_stats("test")
+
+
+if __name__ == '__main__':
+    source_dir = "/Users/gongbing/PycharmProjects/weatherbench"
+    destination_dir = "/Users/gongbing/PycharmProjects/20221201T100105_gong1_wb_convlstm_gan"
+    datasplit_path = os.path.join(destination_dir, "data_split.json")
+    hparams_path = os.path.join(destination_dir, "model_hparams.json")
+    normalize = ZScore
+    template = "weatherbench_{year}-{month:02}.nc"
+    ins = Dataset(input_dir=source_dir, output_dir=destination_dir, datasplit_path=datasplit_path,
+                  hparams_path=hparams_path, normalize=normalize, filename_template=template)
+    dt = ins.make_dataset(mode="train")
+    for i, data in enumerate(dt):
+        print(data)
\ No newline at end of file
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/stats.py b/video_prediction_tools/model_modules/video_prediction/datasets/stats.py
index a1934e6290712f715691fca02554bc012eda7dd6..6b5fb9b428eb9d8dec75d881245fc08d12dce9ad 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/stats.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/stats.py
@@ -1,11 +1,11 @@
 from abc import ABC, abstractmethod
 import dataclasses as dc
 import json
-from pathlib import Path
-
 import numpy as np
 import tensorflow as tf
 
+
 @dc.dataclass
 class VarStats:
     mean: float
diff --git a/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py
index 919b0d5ac64f4f9cc081d8aba681974870e91f3f..e25a4f12d5c5566a787c1fecaf7bb3651a67beed 100644
--- a/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py
+++ b/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py
@@ -4,7 +4,81 @@
 
 import tensorflow as tf
 from .layer_def import *
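+# in TF2, LSTMStateTuple is only available via the tf.compat.v1 namespace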
+from tensorflow.compat.v1.nn.rnn_cell import LSTMStateTuple
 
+# class Zero_state(tf.Module):
+#     """Abstract object representing an Convolutional RNN cell.
+#     """
+#     def __init__(self, shape, num_features, name=None):
+#         super().__init__(name=name)
+#
+#         self.shape = shape
+#         self.num_features = num_features
+#
+#     def __call__(self, x):
+#         """Return zero-filled state tensor(s).
+#         Args:
+#           batch_size: int, float, or unit Tensor representing the batch size.
+#           dtype     : the data type to use for the state.
+#         Returns:
+#           tensor of shape '[batch_size x shape[0] x shape[1] x num_features]
+#           filled with zeros
+#         """
+#         num_features = self.num_features
+#         zeros = tf.zeros([tf.shape(x)[0], self.shape[0], self.shape[1], num_features * 2])
+#         return zeros
+#
+#
+# class BasicConvLSTMCell(tf.Module):
+#     """Basic ConvLSTM recurrent network cell. The
+#     """
+#     def __init__(self, shape, filter_size, num_features, forget_bias=1.0,
+#                  state_is_tuple=False, activation=tf.nn.tanh, name=None):
+#         super().__init__(name=name)
+#         """Initialize the basic Conv LSTM cell.
+#         Args:
+#           shape         : int tuple thats the height and width of the cell
+#           filter_size   : int tuple thats the height and width of the filter
+#           num_features  : int thats the depth of the cell
+#           forget_bias   : float, The bias added to forget gates (see above).
+#           input_size    : Deprecated and unused.
+#           state_is_tuple: If True, accepted and returned states are 2-tuples of
+#           the `c_state` and `m_state`.  If False, they are concatenated
+#           along the column axis.  The latter behavior will soon be deprecated.
+#           activation: Activation function of the inner states.
+#         """
+#
+#         self.shape = shape
+#         self.filter_size = filter_size
+#         self.num_features = num_features
+#         self._forget_bias = forget_bias
+#         self._state_is_tuple = state_is_tuple
+#         self._activation = activation
+#
+#         """Long short-term memory cell (LSTM)."""
+#
+#     def __call__(self, inputs, state):
+#         if self._state_is_tuple:
+#             c, h = state
+#         else:
+#             c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state)
+#
+#         input_h = [inputs, h]
+#
+#         input_h_con = tf.concat(axis = 3, values = input_h)
+#         concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4,
+#                             "decode_1", activate="sigmoid")
+#         i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat)
+#         new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) *
+#                      self._activation(j))
+#         new_h = self._activation(new_c) * tf.nn.sigmoid(o)
+#
+#         if self._state_is_tuple:
+#                 new_state = LSTMStateTuple(new_c, new_h)
+#         else:
+#             new_state = tf.concat(axis = 3, values = [new_c, new_h])
+#
+#         return new_h, new_state
+
+
 class ConvRNNCell(object):
     """Abstract object representing an Convolutional RNN cell.
     """
@@ -25,7 +99,7 @@ class ConvRNNCell(object):
         """Integer or TensorShape: size of outputs produced by this cell."""
         raise NotImplementedError("Abstract method")
 
-    def zero_state(self,input, dtype):
+    def zero_state(self, input):
         """Return zero-filled state tensor(s).
         Args:
           batch_size: int, float, or unit Tensor representing the batch size.
@@ -35,11 +109,9 @@ class ConvRNNCell(object):
           filled with zeros
         """
 
-        shape = self.shape
+
         num_features = self.num_features
-        #x= tf.placeholder(tf.float32, shape=[input.shape[0], shape[0], shape[1], num_features * 2])#Bing: add this to
-        zeros = tf.zeros([tf.shape(input)[0], shape[0], shape[1], num_features * 2])
-        #zeros = tf.zeros_like(x)
+        zeros = tf.zeros([input.shape[0], self.shape[0], self.shape[1], num_features * 2])
         return zeros
 
 
@@ -84,74 +156,26 @@ class BasicConvLSTMCell(ConvRNNCell):
 
     def __call__(self, inputs, state, scope=None,reuse=None):
         """Long short-term memory cell (LSTM)."""
-        with tf.variable_scope(scope or type(self).__name__,reuse=reuse):  # "BasicLSTMCell"
-            # Parameters of gates are concatenated into one multiply for efficiency.
-            if self._state_is_tuple:
-                c, h = state
-            else:
-                c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state)
-
-            input_h = [inputs,h]
-            #Bing20200930#replace with non-linear convolutional layers
-            #concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True)
-            input_h_con = tf.concat(axis = 3, values = input_h)
-            concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4, "decode_1", activate="sigmoid")
-            i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat)
-            new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) *
-                     self._activation(j))
-            new_h = self._activation(new_c) * tf.nn.sigmoid(o)
-
-            if self._state_is_tuple:
-                new_state = LSTMStateTuple(new_c, new_h)
-            else:
-                new_state = tf.concat(axis = 3, values = [new_c, new_h])
-
-            return new_h, new_state
-
-
-def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=None):
-    """convolution:
-    Args:
-      args: a 4D Tensor or a list of 4D, batch x n, Tensors.
-      filter_size: int tuple of filter height and width.
-      num_features: int, number of features.
-      bias_start: starting value to initialize the bias; 0 by default.
-      scope: VariableScope for the created subgraph; defaults to "Linear".
-    Returns:
-      A 4D Tensor with shape [batch h w num_features]
-    Raises:
-      ValueError: if some of the arguments has unspecified or wrong shape.
-    """
 
-    # Calculate the total size of arguments on dimension 1.
-    total_arg_size_depth = 0
-    shapes = [a.get_shape().as_list() for a in args]
-    for shape in shapes:
-        if len(shape) != 4:
-            raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes))
-        if not shape[3]:
-            raise ValueError("Linear expects shape[4] of arguments: %s" % str(shapes))
+        # Parameters of gates are concatenated into one multiply for efficiency.
+        if self._state_is_tuple:
+            c, h = state
         else:
-            total_arg_size_depth += shape[3]
+            c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state)
 
-    dtype = [a.dtype for a in args][0]
+        input_h = [inputs,h]
+        input_h_con = tf.concat(axis = 3, values = input_h)
+        concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4, "decode_1", activate="sigmoid")
+        i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat)
+        new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) *
+                 self._activation(j))
+        new_h = self._activation(new_c) * tf.nn.sigmoid(o)
 
-    # Now the computation.
-    with tf.variable_scope(scope or "Conv"):
-        matrix = tf.get_variable(
-            "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype)
-        if len(args) == 1:
+        if self._state_is_tuple:
+            new_state = LSTMStateTuple(new_c, new_h)
+        else:
+            new_state = tf.concat(axis = 3, values = [new_c, new_h])
 
-            res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME')
+        return new_h, new_state
 
-        else:
-            res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME')
-        if not bias:
-            return res
-        bias_term = tf.get_variable(
-            "Bias", [num_features],
-            dtype = dtype,
-            initializer = tf.constant_initializer(
-                bias_start, dtype = dtype))
-    return res + bias_term
 
diff --git a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py
index 6f7c4f38b222afecb2b21d36bedc938b7813399a..33b2a9f3f84f6a82f4997e34f53e27cf4f34d2a9 100644
--- a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py
+++ b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py
@@ -31,11 +31,11 @@ def _variable_on_gpu(name, shape, initializer):
       Variable Tensor
     """
     with tf.device('/gpu:0'):
-        var = tf.get_variable(name, shape, initializer=initializer)
+        var = tf.Variable(initializer(shape=shape), name=name)
     return var
 
 
-def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.layers.xavier_initializer()):
+def _variable_with_weight_decay(name, shape, stddev, wd):
     """Helper to create an initialized Variable with weight decay.
     Note that the Variable is initialized with a truncated normal distribution.
     A weight decay is added only if one is specified.
@@ -48,117 +48,113 @@ def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.l
     Returns:
       Variable Tensor
     """
-    #var = _variable_on_gpu(name, shape,tf.truncated_normal_initializer(stddev = stddev))
+    initializer = tf.initializers.GlorotUniform()
     var = _variable_on_gpu(name, shape, initializer)
     if wd:
         weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name = 'weight_loss')
         weight_decay.set_shape([])
-        tf.add_to_collection('losses', weight_decay)
+        tf.compat.v1.add_to_collection('losses', weight_decay)
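+        # The collected decay terms can later be folded into the total loss,
+        # e.g. via tf.add_n(tf.compat.v1.get_collection('losses')).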
     return var
 
 
-def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"):
-
-    with tf.variable_scope('{0}_conv'.format(idx)) as scope:
- 
-        input_channels = inputs.get_shape()[-1]
-        weights = _variable_with_weight_decay('weights', shape = [kernel_size, kernel_size,
-                                                                 input_channels, num_features],
-                                              stddev = 0.01, wd = weight_decay)
-        biases = _variable_on_gpu('biases', [num_features], initializer)
-        conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding='SAME')
-        conv_biased = tf.nn.bias_add(conv, biases)
-        if activate == "linear":
-            return conv_biased
-        elif activate == "relu":
-            conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx))  
-        elif activate == "elu":
-            conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx))   
-        elif activate == "leaky_relu":
-            conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx))
-        elif activate == "sigmoid":
-            conv_rect = tf.nn.sigmoid(conv_biased, name = '{0}_conv'.format(idx))
-        else:
-            raise ("activation function is not correct")
-        return conv_rect
-
-
-def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer(),activate="relu"):
-    with tf.variable_scope('{0}_trans_conv'.format(idx)) as scope:
-        input_channels = inputs.get_shape()[3]
-        input_shape = inputs.get_shape().as_list()
-
-
-        weights = _variable_with_weight_decay('weights',
-                                              shape = [kernel_size, kernel_size, num_features, input_channels],
-                                              stddev = 0.1, wd = weight_decay)
-        biases = _variable_on_gpu('biases', [num_features],initializer)
-        batch_size = tf.shape(inputs)[0]
-
-        output_shape = tf.stack(
-            [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features])
-
-        conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME')
-        conv_biased = tf.nn.bias_add(conv, biases)
-        if activate == "linear":
-            return conv_biased
-        elif activate == "elu":
-            return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx))       
-        elif activate == "relu":
-            return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
-        elif activate == "leaky_relu":
-            return tf.nn.leaky_relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
-        elif activate == "sigmoid":
-            return tf.nn.sigmoid(conv_biased, name ='sigmoid') 
-        else:
-            return conv_biased
+def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.initializers.GlorotUniform(), activate="relu"):
+
+    input_channels = inputs.get_shape()[-1]
+    weights = _variable_with_weight_decay('weights', shape = [kernel_size, kernel_size,
+                                                             input_channels, num_features],
+                                          stddev = 0.01, wd = weight_decay)
+    biases = _variable_on_gpu('biases', [num_features], initializer)
+    conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding='SAME')
+    conv_biased = tf.nn.bias_add(conv, biases)
+    if activate == "linear":
+        return conv_biased
+    elif activate == "relu":
+        conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx))
+    elif activate == "elu":
+        conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx))
+    elif activate == "leaky_relu":
+        conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx))
+    elif activate == "sigmoid":
+        conv_rect = tf.nn.sigmoid(conv_biased, name = '{0}_conv'.format(idx))
+    else:
+        raise ("activation function is not correct")
+    return conv_rect
+
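+# Minimal usage sketch (a hedged example, not part of the original code; assumes
+# eager execution and that the module-level `weight_decay` is defined):
+#     x = tf.zeros([4, 32, 32, 3])                      # NHWC dummy batch
+#     y = conv_layer(x, 3, 1, 16, "enc1", activate="relu")
+#     y.shape                                           # -> (4, 32, 32, 16)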
+
+def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.initializers.GlorotUniform(),activate="relu"):
+
+    input_channels = inputs.get_shape()[3]
+    input_shape = inputs.get_shape().as_list()
+
+
+    weights = _variable_with_weight_decay('weights',
+                                          shape = [kernel_size, kernel_size, num_features, input_channels],
+                                          stddev = 0.1, wd = weight_decay)
+    biases = _variable_on_gpu('biases', [num_features],initializer)
+
+    output_shape = tf.stack(
+        [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features])
+
+    conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME')
+    conv_biased = tf.nn.bias_add(conv, biases)
+    if activate == "linear":
+        return conv_biased
+    elif activate == "elu":
+        return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+    elif activate == "relu":
+        return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+    elif activate == "leaky_relu":
+        return tf.nn.leaky_relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+    elif activate == "sigmoid":
+        return tf.nn.sigmoid(conv_biased, name ='sigmoid')
+    else:
+        return conv_biased
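+
+# Usage sketch (hedged): with stride 2 the spatial dimensions double, e.g.
+#     x = tf.zeros([4, 16, 16, 32])
+#     y = transpose_conv_layer(x, 4, 2, 16, "dec1")     # -> (4, 32, 32, 16)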
     
 
-def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01,initializer=tf.contrib.layers.xavier_initializer()):
-    with tf.variable_scope('{0}_fc'.format(idx)) as scope:
-        input_shape = inputs.get_shape().as_list()
-        if flat:
-            dim = input_shape[1] * input_shape[2] * input_shape[3]
-            inputs_processed = tf.reshape(inputs, [-1, dim])
-        else:
-            dim = input_shape[1]
-            inputs_processed = inputs
-
-        weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init,
-                                              wd = weight_decay)
-        biases = _variable_on_gpu('biases', [hiddens],initializer)
-        if activate == "linear":
-            return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')
-        elif activate == "sigmoid":
-            return tf.nn.sigmoid(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
-        elif activate == "softmax":
-            return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
-        elif activate == "relu":
-            return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
-        elif activate == "leaky_relu":
-            return tf.nn.leaky_relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))        
-        else:
-            ip = tf.add(tf.matmul(inputs_processed, weights), biases)
-            return tf.nn.elu(ip, name = str(idx) + '_fc')
-
-        
-def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None):
-    with tf.variable_scope('{0}_bn'.format(idx)) as scope:
-        #Calculate batch mean and variance
-        shape = inputs.get_shape().as_list()
-        scale = tf.get_variable("gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=is_training)
-        beta = tf.get_variable("beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=is_training)
-        pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
-        pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)
-        
-        if is_training:
-            batch_mean, batch_var = tf.nn.moments(inputs,[0])
-            train_mean = tf.assign(pop_mean,pop_mean * decay + batch_mean * (1 - decay))
-            train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
-            with tf.control_dependencies([train_mean,train_var]):
-                 return tf.nn.batch_normalization(inputs,batch_mean,batch_var,beta,scale,epsilon)
-        else:
-             return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon)
+def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01,initializer=tf.initializers.GlorotUniform()):
+
+    input_shape = inputs.get_shape().as_list()
+    if flat:
+        dim = input_shape[1] * input_shape[2] * input_shape[3]
+        inputs_processed = tf.reshape(inputs, [-1, dim])
+    else:
+        dim = input_shape[1]
+        inputs_processed = inputs
+
+    weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init,
+                                          wd = weight_decay)
+    biases = _variable_on_gpu('biases', [hiddens],initializer)
+    if activate == "linear":
+        return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')
+    elif activate == "sigmoid":
+        return tf.nn.sigmoid(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+    elif activate == "softmax":
+        return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+    elif activate == "relu":
+        return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+    elif activate == "leaky_relu":
+        return tf.nn.leaky_relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc'))
+    else:
+        ip = tf.add(tf.matmul(inputs_processed, weights), biases)
+        return tf.nn.elu(ip, name = str(idx) + '_fc')
+
+
+def bn_layers(inputs,is_training=True, epsilon=1e-3,decay=0.99):
+
+    shape = inputs.get_shape().as_list()
+    scale = tf.get_variable("gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=is_training)
+    beta = tf.get_variable("beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=is_training)
+    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
+    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)
+
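+    # Training: normalise with the batch statistics and update the population
+    # (moving-average) statistics as a side effect; at inference time the stored
+    # population statistics are used instead.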
+    if is_training:
+        batch_mean, batch_var = tf.nn.moments(inputs,[0])
+        train_mean = pop_mean.assign(pop_mean * decay + batch_mean * (1 - decay))
+        train_var = pop_var.assign(pop_var * decay + batch_var * (1 - decay))
+        with tf.control_dependencies([train_mean, train_var]):
+            return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)
+    else:
+        return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon)
 
 
 class batch_norm(object):
diff --git a/video_prediction_tools/model_modules/video_prediction/metrics.py b/video_prediction_tools/model_modules/video_prediction/metrics.py
index 570efd72d993808f143d632d6734eebffb9b9f7e..9e55bbc4683d5002ae6bb72e0c56d3cc91a1ff79 100644
--- a/video_prediction_tools/model_modules/video_prediction/metrics.py
+++ b/video_prediction_tools/model_modules/video_prediction/metrics.py
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: MIT
 
 import tensorflow as tf
-#import lpips_tf
 import math
 import numpy as np
 try:
@@ -15,8 +14,6 @@ except:
         print("Could not get ssmi-function from skimage. Please check installed skimage-package.")
         raise err
 
-
-
 def mse(a, b):
     return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
 
@@ -39,14 +36,6 @@ def mse_imgs(image1,image2):
     mse = ((image1 - image2)**2).mean(axis=None)
     return mse
 
-# def lpips(input0, input1):
-#     if input0.shape[-1].value == 1:
-#         input0 = tf.tile(input0, [1] * (input0.shape.ndims - 1) + [3])
-#     if input1.shape[-1].value == 1:
-#         input1 = tf.tile(input1, [1] * (input1.shape.ndims - 1) + [3])
-#
-#     distance = lpips_tf.lpips(input0, input1)
-#     return -distance
 
 def ssim_images(image1, image2):
     """
diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
index 5bad1430f9c7499965899fb948bd46bb8e7686c9..2a164529f4427876084256f424a483dd273ba154 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
@@ -34,7 +34,8 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
             self.learning_rate = self.hparams.lr
             self.sequence_length = self.hparams.sequence_length
             self.opt_var = self.hparams.opt_var
-            self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
+            self.predict_frames = set_and_check_pred_frames(self.sequence_length,
+                                                            self.context_frames)
             self.ngf = self.hparams.ngf
             self.ndf = self.hparams.ndf
 
@@ -43,38 +44,38 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
            raise ValueError("Method %{}: the hyper-parameter dictionary must include parameters above")
 
 
-    def build_graph(self, x: tf.Tensor):
+    @tf.function
+    def train_step(self, x: tf.Tensor, step):
 
-        self.inputs = x
+        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
+            x_hat = self.build_model(x)
 
-        self.global_step = tf.train.get_or_create_global_step()
-        original_global_variables = tf.global_variables()
+            self.total_loss = self.get_loss(x, x_hat)
 
-        #Build graph
-        x_hat = self.build_model(x)
+            generator_gradients = gen_tape.gradient(self.total_loss, self.gen_vars)
+            discriminator_gradients = disc_tape.gradient(self.D_loss, self.disc_vars)
 
-        #Get losses (reconstruciton loss, total loss and descriminator loss)
-        self.total_loss = self.get_loss(x, x_hat)
 
-        #Define optimizer
-        self.train_op = self.optimizer(self.total_loss)
+            self.G_solver.apply_gradients(zip(generator_gradients,
+                                              self.gen_vars))
+            self.D_solver.apply_gradients(zip(discriminator_gradients,
+                                              self.disc_vars))
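+            # Assumption: self.G_solver and self.D_solver are tf.keras optimizers
+            # created during model setup, e.g.
+            #     self.G_solver = tf.keras.optimizers.Adam(self.learning_rate)
+            #     self.D_solver = tf.keras.optimizers.Adam(self.learning_rate)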
 
-        #Save to outputs
-        self.outputs["gen_images"] = x_hat
-        self.outputs["total_loss"] = self.total_loss
-        # Summary op
-        sum_dict = {"total_loss": self.total_loss,
-                  "D_loss": self.D_loss,
-                  "G_loss": self.G_loss,
-                  "D_loss_fake": self.D_loss_fake,
-                  "D_loss_real": self.D_loss_real,
-                  "recon_loss": self.recon_loss}
+            #Save to outputs
+            self.outputs["gen_images"] = x_hat
+            self.outputs["total_loss"] = self.total_loss
+            # Summary op
+            sum_dict = {"total_loss": self.total_loss,
+                      "D_loss": self.D_loss,
+                      "G_loss": self.G_loss,
+                      "D_loss_fake": self.D_loss_fake,
+                      "D_loss_real": self.D_loss_real,
+                      "recon_loss": self.recon_loss}
+
+            self.summary_op = self.summary(**sum_dict)
 
-        self.summary_op = self.summary(**sum_dict)
-        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
-        self.saveable_variables = [self.global_step] + global_variables
-        self.is_build_graph = True
-        return self.is_build_graph
 
     def get_loss(self, x: tf.Tensor, x_hat: tf.Tensor):
         """
@@ -83,7 +84,6 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         self.G_loss = self.get_gen_loss()
         self.D_loss = self.get_disc_loss()
         self._get_vars()
-        #self.recon_loss = self.get_loss(self, x, x_hat) #use the loss from vanilla convLSTM
 
         if self.opt_var == "all":
             x = x[:, self.context_frames:, :, :, :]
@@ -111,33 +111,14 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         return total_loss
 
     def optimizer(self, *args):
+        pass
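+        # Stub: with eager execution the gradients are applied directly in
+        # train_step via apply_gradients, so no graph-mode minimize() op is
+        # built here any more.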
 
-        if self.mode == "train":
-            if self.recon_weight == 1:
-                print("Only train generator- ConvLSTM")
-                train_op = tf.train.AdamOptimizer(learning_rate =
-                                                       self.learning_rate).\
-                    minimize(self.total_loss, var_list=self.gen_vars)
-            else:
-                print("Training discriminator")
-                self.D_solver = tf.train.AdamOptimizer(learning_rate =self.learning_rate).\
-                    minimize(self.D_loss, var_list=self.disc_vars)
-                with tf.control_dependencies([self.D_solver]):
-                    print("Training generator....")
-                    self.G_solver = tf.train.AdamOptimizer(learning_rate =self.learning_rate).\
-                        minimize(self.total_loss, var_list=self.gen_vars)
-                with tf.control_dependencies([self.G_solver]):
-                    train_op = tf.assign_add(self.global_step, 1)
-        else:
-           train_op = None
-        return train_op
 
 
     def build_model(self, x):
         """
         Define gan architectures
         """
-        #conditional GAN
         x_hat = self.generator(x)
 
         self.D_real, self.D_real_logits = self.discriminator(self.inputs[:, self.context_frames:, :, :, 0:1])
@@ -163,21 +144,49 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         return x_hat
 
 
     def discriminator(self, vid):
         """
         Function that get discriminator architecture
         """
-        with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE):
-            conv1 = tf.layers.conv3d(vid,64,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis1")
+        with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
+            conv1 = tf.layers.conv3d(vid, 64, kernel_size=[4,4,4],
+                                     strides=[2,2,2],
+                                     padding="SAME",
+                                     name="dis1")
             conv1 = self._lrelu(conv1)
-            conv2 = tf.layers.conv3d(conv1, 128, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis2")
+            conv2 = tf.layers.conv3d(conv1, 128,
+                                     kernel_size=[4, 4, 4],
+                                     strides=[2,2,2],
+                                     padding="SAME",
+                                     name="dis2")
             conv2 = self._lrelu(self.bd1(conv2))
-            conv3 = tf.layers.conv3d(conv2, 256, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME" ,name="dis3")
+            conv3 = tf.layers.conv3d(conv2, 256,
+                                     kernel_size=[4, 4, 4],
+                                     strides=[2,2,2],
+                                     padding="SAME",
+                                     name="dis3")
             conv3 = self._lrelu(self.bd2(conv3))
-            conv4 = tf.layers.conv3d(conv3, 512, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis4")
+            conv4 = tf.layers.conv3d(conv3, 512,
+                                     kernel_size=[4, 4, 4],
+                                     strides=[2, 2, 2],
+                                     padding="SAME",
+                                     name="dis4")
             conv4 = self._lrelu(self.bd3(conv4))
-            conv5 = tf.layers.conv3d(conv4, 1, kernel_size=[2,4,4],strides=[1,1,1],padding="SAME", name="dis5")
-            conv5 = tf.reshape(conv5, [-1,1])
+            conv5 = tf.layers.conv3d(conv4, 1,
+                                     kernel_size=[2, 4, 4],
+                                     strides=[1, 1, 1],
+                                     padding="SAME",
+                                     name="dis5")
+            conv5 = tf.reshape(conv5, [-1, 1])
             conv5sigmoid = tf.nn.sigmoid(conv5)
             return conv5sigmoid, conv5
 
@@ -187,13 +196,14 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         """
         real_labels = tf.ones_like(self.D_real)
         gen_labels = tf.zeros_like(self.D_fake)
-        self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real_logits, labels=real_labels))
-        self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=gen_labels))
+        self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
+            logits=self.D_real_logits, labels=real_labels))
+        self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
+            logits=self.D_fake_logits, labels=gen_labels))
         D_loss = self.D_loss_real + self.D_loss_fake
         return D_loss
 
 
-
     def get_gen_loss(self):
         """
         Param:
@@ -203,7 +213,7 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         """
         real_labels = tf.ones_like(self.D_fake)
         G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits,
-                                                                             labels=real_labels))
+                                                                        labels=real_labels))
         return G_loss
    
     def _get_vars(self):
@@ -214,7 +224,6 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
 
 
-
     def _lrelu(self, x, leak=0.2):
         return tf.maximum(x, leak * x)
 
diff --git a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
index 589a6b32b7fe9c80646e53310d54272f696cb88f..f5dbead33b1540acd9322a93d239f93873004654 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
@@ -5,7 +5,8 @@
 __email__ = "b.gong@fz-juelich.de"
 __author__ = "Bing Gong"
 __date__ = "2022-04-13"
-
+import sys
+sys.path.append("../utils")
 from hparams_utils import *
 import json
 from abc import ABC, abstractmethod
@@ -33,10 +34,11 @@ class BaseModels(ABC):
         self.outputs = {}
         self.loss_summary = None
         self.summary_op = None
-        self.global_step = tf.train.get_or_create_global_step()
         self.saveable_variables = None
         self._is_build_graph_set = False
 
     
     def hparams_options(self, hparams_dict_config:str):
         if hparams_dict_config:
@@ -71,7 +73,7 @@ class BaseModels(ABC):
         return self.hparams
 
 
-    def build_graph(self, x: tf.Tensor)->bool:
+    def train_step(self, x: tf.Tensor):
         """
         This function is used for build the graph, and allow a optimiser to the graph by using tensorflow function.
 
@@ -90,17 +92,17 @@ class BaseModels(ABC):
                     return self._is_build_graph_set
 
         """
-        self.inputs = x
-        original_global_variables = tf.global_variables()
-        x_hat = self.build_model(x)
-        self.total_loss = self.get_loss(x, x_hat)
-        self.train_op = self.optimizer(self.total_loss)
-        self.outputs["gen_images"] = x_hat
-        self.summary_op = self.summary(total_loss = self.total_loss)
-        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
-        self.saveable_variables = [self.global_step] + global_variables
-        self._is_build_graph_set = True
-        return self._is_build_graph_set
+        # self.inputs = x
+        # original_global_variables = tf.global_variables()
+        # x_hat = self.build_model(x)
+        # self.total_loss = self.get_loss(x, x_hat)
+        # self.train_op = self.optimizer(self.total_loss)
+        # self.outputs["gen_images"] = x_hat
+        # self.summary_op = self.summary(total_loss = self.total_loss)
+        # global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        # self.saveable_variables = [self.global_step] + global_variables
+        # self._is_build_graph_set = True
+
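+        # The graph-mode body above is kept for reference only; concrete models
+        # now override train_step with an eager implementation (see
+        # vanilla_convLSTM_model.py).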
 
 
     def optimizer(self, total_loss):
@@ -141,14 +143,14 @@ class BaseModels(ABC):
 
 
 
-    @abstractmethod
-    def build_model(self, x)->tf.Tensor:
-        """
-        This function is used to create the network
-        Example: see example in vanilla_convLSTM_model.py, it must return prediction fnsrames and save it to the self.output
-        which is used for calculating the loss
-        """
-        pass
+    # @abstractmethod
+    # def build_model(self, x)->tf.Tensor:
+    #     """
+    #     This function is used to create the network
+    #     Example: see vanilla_convLSTM_model.py; it must return the prediction frames and store them in self.outputs,
+    #     which are used for calculating the loss
+    #     """
+    #     pass
 
 
 
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
index 70c050f53efa81f1c482fa06f3cd1a5318b6e987..4d6988c3c95768a3a7cb89716c1d17a48e6013a9 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
@@ -4,7 +4,7 @@
 
 __email__ = "b.gong@fz-juelich.de"
 __author__ = "Bing Gong"
-__date__ = "2020-11-05"
+__date__ = "2022-01-05"
 
 from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames
 import tensorflow as tf
@@ -12,6 +12,7 @@ from model_modules.video_prediction.layers import layer_def as ld
 from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 from .our_base_model import BaseModels
 from hparams_utils import *
+tf.config.experimental_run_functions_eagerly(True)
 
 class VanillaConvLstmVideoPredictionModel(BaseModels):
 
@@ -36,27 +37,47 @@ class VanillaConvLstmVideoPredictionModel(BaseModels):
             self.batch_size = hparams.batch_size
             self.shuffle_on_val = hparams.shuffle_on_val
             self.loss_fun = hparams.loss_fun
-            self.opt_var = hparams.opt_var
+            self.opt_var = "all"
+                #hparams.opt_var
             self.lr = hparams.lr
             self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
             print("The model hparams have been parsed successfully! ")
         except Exception as e:
-            raise ValueError(f"missing hyperparameter: {e.args[0]}")
+            raise ValueError(f"missing hyper-parameter: {e.args[0]}")
 
 
+    def compile(self, input_shape):
+        self.model = ConvLSTM_network(input_shape, self.sequence_length, self.context_frames)
+        self.optimizer = tf.keras.optimizers.Adam(self.lr)
+
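+    # Minimal training-loop sketch (a hedged example; `train_ds` is assumed to
+    # yield batches shaped [batch, sequence_length, height, width, channels]):
+    #     model.compile(input_shape=x.shape)
+    #     for step, x in enumerate(train_ds):
+    #         model.train_step(x, step)
+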
+    @tf.function
+    def train_step(self, x, step):
+        print("x.shape", x.shape)
+
+        # Open a GradientTape to record the operations run
+        # during the forward pass, which enables auto-differentiation.
+        with tf.GradientTape() as tape:
+
+            # Run the forward pass of the layer.
+            # The operations that the layer applies
+            # to its inputs are going to be recorded
+            # on the GradientTape.
+
+            x_hat = self.model(x)
+
+            # Compute the loss value for this minibatch.
+            self.loss_value = self.get_loss(x, x_hat)
+
+        # Use the gradient tape to automatically retrieve
+        # the gradients of the trainable variables with respect to the loss.
+        grads = tape.gradient(self.loss_value, self.model.network_template.trainable_variables)
+
+        # Run one step of gradient descent by updating
+        # the value of the variables to minimize the loss.
+        self.optimizer.apply_gradients(zip(grads, self.model.network_template.trainable_variables))
+        print("optimiasec")
+        #print("after training:", self.optimizer.iterations.numpy())
 
-    def build_graph(self, x:tf.Tensor):
-        self.inputs = x
-        original_global_variables = tf.global_variables()
-        x_hat = self.build_model(x)
-        self.total_loss = self.get_loss(x, x_hat)
-        self.train_op = self.optimizer(self.total_loss)
-        self.outputs["gen_images"] = x_hat
-        self.summary_op = self.summary(total_loss = self.total_loss)
-        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
-        self.saveable_variables = [self.global_step] + global_variables
-        self._is_build_graph_set = True
-        return self._is_build_graph_set
 
 
     def get_loss(self, x:tf.Tensor, x_hat:tf.Tensor)->tf.Tensor:
@@ -65,18 +86,17 @@ class VanillaConvLstmVideoPredictionModel(BaseModels):
         :param x_hat: Prediction/output tensors
         :return     : the loss function
         """
-        #This is the loss function (MSE):
-        #Optimize all target variables/channels
+
         if self.opt_var == "all":
             x = x[:, self.context_frames:, :, :, :]
-            print("The model is optimzied on all the variables in the loss function")
         elif self.opt_var != "all" and isinstance(self.opt_var, str):
             self.opt_var = int(self.opt_var)
-            print("The model is optimized on the {} variable in the loss function".format(self.opt_var))
             x = x[:, self.context_frames:, :, :, self.opt_var]
             x_hat = x_hat[:, :, :, :, self.opt_var]
         else:
-            raise ValueError("The opt var in the hyperparameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables")
+            raise ValueError("The opt var in the hyperparameters setup should "
+                             "be '0','1','2' indicate the index of target variable "
+                             "to be optimised or 'all' indicating optimize all the variables")
 
         if self.loss_fun == "mse":
             total_loss = tf.reduce_mean(tf.square(x - x_hat))
@@ -86,72 +106,101 @@ class VanillaConvLstmVideoPredictionModel(BaseModels):
             bce = tf.keras.losses.BinaryCrossentropy()
             total_loss = bce(x_flatten, x_hat_predict_frames_flatten)
         else:
-            raise ValueError("Loss function is not selected properly, you should chose either 'mse' or 'cross_entropy'")
+            raise ValueError("Loss function is not selected properly, "
+                             "you should chose either 'mse' or 'cross_entropy'")
         return total_loss
 
 
-
-    def summary(self, total_loss)->None:
+    def summary(self, total_loss, step)->None:
         """
         return the summary operation can be used for TensorBoard
         """
-        tf.summary.scalar("total_loss", total_loss)
+        tf.summary.scalar("total_loss", total_loss, step)
         summary_op = tf.summary.merge_all()
         return summary_op
 
-    def build_model(self, x: tf.Tensor):
-        network_template = tf.make_template('network',
-                                            VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
-
-        x_hat = VanillaConvLstmVideoPredictionModel.convLSTM_network(x,
-                                                                     self.sequence_length,
-                                                                     self.context_frames,
-                                                                     network_template)
-        return x_hat
-
-
-    @staticmethod
-    def convLSTM_network(x:tf.Tensor, sequence_length:int, context_frames:int, network_template:tf.make_template)->tf.Tensor:
 
+    # @staticmethod
+    # def convLSTM_network(x:tf.Tensor,
+    #                      sequence_length:int,
+    #                      context_frames:int,
+    #                      network_template:tf.compat.v1.make_template)->tf.Tensor:
+    #
+    #     # create network
+    #     x_hat = []
+    #
+    #     # This is for training (optimization of convLSTM layer)
+    #     hidden_g = None
+    #     for i in range(sequence_length - 1):
+    #         if i < context_frames:
+    #             print("i",i)
+    #             x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g)
+    #         else:
+    #             x_1_g, hidden_g = network_template(x_1_g, hidden_g)
+    #         x_hat.append(x_1_g)
+    #
+    #     # pack them all together
+    #     x_hat = tf.stack(x_hat)
+    #     x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])  # change first dim with sec dim
+    #     x_hat = x_hat[:, context_frames - 1:, :, :, :]
+    #     return x_hat
+
+
+
+class ConvLSTM_network(tf.Module):
+
+    def __init__(self, shape:list, sequence_length:int, context_frames:int, name=None):
+        super(ConvLSTM_network, self).__init__(name = name)
+
+        self.shape = shape
+        self.sequence_length = sequence_length
+        self.context_frames = context_frames
+        self.network_template = tf.compat.v1.make_template('network',
+                                                           ConvLSTM_network.convLSTM_cell)
+
+    @tf.Module.with_name_scope
+    def __call__(self, x):
         # create network
         x_hat = []
-
         # This is for training (optimization of convLSTM layer)
         hidden_g = None
-        for i in range(sequence_length - 1):
-            if i < context_frames:
-                x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g)
+        for i in range(self.sequence_length - 1):
+            if i < self.context_frames:
+                x_1_g, hidden_g = self.network_template(x[:, i, :, :, :], hidden_g)
             else:
-                x_1_g, hidden_g = network_template(x_1_g, hidden_g)
+                x_1_g, hidden_g = self.network_template(x_1_g, hidden_g)
             x_hat.append(x_1_g)
 
         # pack them all together
         x_hat = tf.stack(x_hat)
         x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])  # change first dim with sec dim
-        x_hat = x_hat[:, context_frames - 1:, :, :, :]
+        x_hat = x_hat[:, self.context_frames - 1:, :, :, :]
         return x_hat
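+        # Shape note: for x of shape [B, T, H, W, C] with context_frames = K,
+        # the call returns one-step-ahead predictions of shape [B, T - K, H, W, C].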
 
-
     @staticmethod
-
-    def convLSTM_cell(inputs:tf.Tensor, hidden:tf.Tensor):
+    def convLSTM_cell(inputs: tf.Tensor, hidden: tf.Tensor):
         """
-        SPDX-FileCopyrightText: loliverhennigh 
+        SPDX-FileCopyrightText: loliverhennigh
         SPDX-License-Identifier: Apache-2.0
-        The following function was revised based on the github https://github.com/loliverhennigh/Convolutional-LSTM-in-Tensorflow 
+        The following function was revised based on the github https://github.com/loliverhennigh/Convolutional-LSTM-in-Tensorflow
         """
         y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset
 
         # conv lstm cell
         cell_shape = y_0.get_shape().as_list()
         channels = cell_shape[-1]
-        with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
-            cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64)
-            if hidden is None:
-                hidden = cell.zero_state(y_0, tf.float32)
-            output, hidden = cell(y_0, hidden)
+
+        cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]],
+                                 filter_size = 5,
+                                 num_features = 64)
+        if hidden is None:
+            hidden = cell.zero_state(y_0)
+        output, hidden = cell(y_0, hidden)
         output_shape = output.get_shape().as_list()
         z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
-        #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction
+        # feed the learned representation into a 1x1 convolutional layer to
+        # generate the final prediction
         x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid")
         return x_hat, hidden
+
diff --git a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
index b2a1534e43012d56677259eba57bb60049bad6aa..6b1aeb5860e699c324793ffe7d91021c6d61fd06 100644
--- a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
+++ b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
@@ -1,15 +1,6 @@
-import itertools
-import os
-from collections import OrderedDict
 
-import numpy as np
-import six
+import os
 import tensorflow as tf
-import tensorflow.contrib.graph_editor as ge
-from tensorflow.core.framework import node_def_pb2
-from tensorflow.python.framework import device as pydev
-from tensorflow.python.training import device_setter
-from tensorflow.python.util import nest
 
 
 
diff --git a/video_prediction_tools/postprocess/plot_ambs_forecast.py b/video_prediction_tools/postprocess/plot_ambs_forecast.py
index 1ba2a7b2f4f3ffe0c5c591e4f0f8aecbcb614a35..5f0ce2fbe9645e2f8583ad0aaa9e783ac586d8be 100644
--- a/video_prediction_tools/postprocess/plot_ambs_forecast.py
+++ b/video_prediction_tools/postprocess/plot_ambs_forecast.py
@@ -21,7 +21,6 @@ import matplotlib
 
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
-import cartopy
 from mpl_toolkits.basemap import Basemap
 
 
@@ -102,11 +101,8 @@ def create_plot(data: xr.DataArray, data_ref: xr.DataArray, varname: str, fcst_h
         cbar_ax = fig.add_axes([0.95, pos.y0+0.08*(t*2-1), 0.02, pos.y1 - pos.y0])
         cbar = fig.colorbar(cs, cax=cbar_ax, orientation="vertical",  ticks=cbar_ticks)
         cbar.set_label(cbar_labs[t])
-    # save to disk
-    #plt.show()
-    #plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=-0.5, wspace=0.05)
+
     plt.subplots_adjust(hspace=-0.5)
-    #plt.tight_layout()
 
     plt.savefig(plt_fname, bbox_inches="tight")
     plt.close()
@@ -128,7 +124,7 @@ def main(args):
     fhh = args.forecast_hour
 
     if not os.path.isfile(filename):
-        raise FileNotFoundError("Could not find the indictaed netCDF-file '{0}'".format(filename))
+        raise FileNotFoundError("Could not find the indicated netCDF-file '{0}'".format(filename))
 
     with xr.open_dataset(filename) as dfile:
         t2m_fcst, t2m_ref = dfile["2t_fcst"], dfile["2t_ref"]