diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 92d33264d3dd831638a4afb33bc0fc1f18e41d79..08202678b896f43630ae2cecf88f34fe2fe67298 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -7,13 +7,14 @@ We took the code implementation from https://github.com/alexlee-gk/video_predict
 """
 
 __email__ = "b.gong@fz-juelich.de"
-__author__ = "Bing Gong"
+__author__ = "Bing Gong, Michael Langguth"
 __date__ = "2020-10-22"
 
 import argparse
 import errno
 import json
 import os
+from typing import Union, List
 import random
 import time
 import numpy as np
@@ -22,30 +23,30 @@ from model_modules.video_prediction import datasets, models
 import matplotlib.pyplot as plt
 import pickle as pkl
 from model_modules.video_prediction.utils import tf_utils
+from general_utils import *
 
 
 class TrainModel(object):
-    def __init__(self, input_dir=None, output_dir=None, datasplit_dict=None,
-                 model_hparams_dict=None, model=None, checkpoint=None, dataset=None,
-                 gpu_mem_frac=None, seed=None, args=None, save_diag_intv=20, save_model_intv = 1000):
-        
+    def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None,
+                 model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None,
+                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01):
+        """
+        Class instance for training the models
+        :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located
+        :param output_dir: directory where all the output is saved (e.g. model, JSON-files, training curves etc.)
+        :param datasplit_dict: JSON-file for defining data splitting
+        :param model_hparams_dict: JSON-file of model hyperparameters
+        :param model: model class name
+        :param checkpoint: checkpoint directory (pre-trained models)
+        :param dataset: dataset class name
+        :param gpu_mem_frac: fraction of GPU memory to be preallocated
+        :param seed: seed of the randomizers
+        :param args: list of arguments passed
+        :param diag_intv_frac: interval for diagnozing and saving model; the fraction with respect to the number of
+                               steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results
+                               into a diagnozing intreval of 10 iteration steps (= interval over which validation loss
+                               is averaged to identify best model performance)
         """
-        This class aims to train the models
-        args:
-            input_dir            : str, the path to the PreprocessData directory which is parent directory of "Pickle" and "tfrecords" files directiory. 
-            output_dir           : str, directory where json files, summary, model, gifs, etc are saved. "
-                                             "default is logs_dir/model_fname, where model_fname consists of "
-                                             "information from model and model_hparams
-            datasplit_dict       : str, the path pointing to the datasplit_config json file
-            model_hparams_dict   : str, a json file of model hyperparameters
-            checkpoint           : str, directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)
-            dataset              : str, dataset class name
-            model                : str, model class name
-            gpu_mem_frac         : float, fraction of gpu memory to use
-            save_diag_intv       : int, interval of iteration steps for which the loss is saved and for which a new
-                                   loss curve is plotted
-            save_model_intv      : int, interval of iteration for which the model is checkpointed
-        """ 
         self.input_dir = os.path.normpath(input_dir)
         self.output_dir = os.path.normpath(output_dir)
         self.datasplit_dict = datasplit_dict
@@ -56,9 +57,12 @@ class TrainModel(object):
         self.gpu_mem_frac = gpu_mem_frac
         self.seed = seed
         self.args = args
+        self.diag_intv_frac = diag_intv_frac
         # for diagnozing and saving the model during training
-        self.save_diag_intv = save_diag_intv
-        self.save_model_intv = save_model_intv
+        self.saver_loss = None         # set in create_fetches_for_train-method
+        self.saver_loss_name = None    # set in create_fetches_for_train-method 
+        self.saver_loss_dict = None    # set in create_fetches_for_train-method if loss of interest is nested 
+        self.diag_intv_step = None     # set in calculate_samples_and_epochs-method
 
     def setup(self):
         self.set_seed()
@@ -87,13 +91,14 @@ class TrainModel(object):
         """
         Checks if output directory is existing.
         """
+        method = TrainModel.check_output_dir.__name__
+
         if self.output_dir is None:
-            raise ValueError("Output_dir-argument is empty. Please define a proper output_dir")
+            raise ValueError("%{0}: Output_dir-argument is empty. Please define a proper output_dir".format(method))
         elif not os.path.isdir(self.output_dir):
-            raise NotADirectoryError("Base output_dir {0} does not exist. Please pass a proper output_dir and "+\
-                                     "make use of config_train.py.")
+            raise NotADirectoryError("Base output_dir {0} does not exist. Pass a proper output_dir".format(method) +
+                                     " and make use of env_setup/generate_runscript.py.")
 
-            
     def get_model_hparams_dict(self):
         """
         Get and read model_hparams_dict from json file to dictionary 
@@ -104,7 +109,6 @@ class TrainModel(object):
                 self.model_hparams_dict_load.update(json.loads(f.read()))
         return self.model_hparams_dict_load
 
-
     def load_params_from_checkpoints_dir(self):
         """
         If checkpoint is none, load and read the json files of datasplit_config, and hparam_config,
@@ -112,6 +116,8 @@ class TrainModel(object):
         If the checkpoint is given, the configuration of dataset, model and options in the checkpoint dir will be
         restored and used for continue training.
         """
+        method = TrainModel.load_params_from_checkpoints_dir.__name__
+
         if self.checkpoint:
             self.checkpoint_dir = os.path.normpath(self.checkpoint)
             if not os.path.isdir(self.checkpoint):
@@ -121,18 +127,18 @@ class TrainModel(object):
             # read and overwrite dataset and model from checkpoint
             try:
                 with open(os.path.join(self.checkpoint_dir, "options.json")) as f:
-                    print("loading options from checkpoint %s" % self.checkpoint)
+                    print("%{0}: Loading options from checkpoint '{1}'".format(method, self.checkpoint))
                     self.options = json.loads(f.read())
                     self.dataset = self.dataset or self.options['dataset']
                     self.model = self.model or self.options['model']
             except FileNotFoundError:
-                print("options.json was not loaded because it does not exist in {0}".format(self.checkpoint_dir))
+                print("%{0}: options.json does not exist in {1}".format(method, self.checkpoint_dir))
             # loading hyperparameters from checkpoint
             try:
                 with open(os.path.join(self.checkpoint_dir, "model_hparams.json")) as f:
                     self.model_hparams_dict_load.update(json.loads(f.read()))
             except FileNotFoundError:
-                print("model_hparams.json was not loaded because it does not exist in {0}".format(self.checkpoint_dir))
+                print("%{0}: model_hparams.json does not exist in {1}".format(method, self.checkpoint_dir))
                 
     def setup_dataset(self):
         """
@@ -144,7 +150,7 @@ class TrainModel(object):
                                           hparams_dict_config=self.model_hparams_dict)
         self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict,
                                         hparams_dict_config=self.model_hparams_dict)
-        # ML/BG 2021-06-15: Is the following needed?
+        # Retrieve sequence length from dataset
         self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
 
     def setup_model(self, mode="train"):
@@ -160,30 +166,27 @@ class TrainModel(object):
         build model graph
         """
         self.video_model.build_graph(self.inputs)
-
         
     def make_dataset_iterator(self):
         """
         Prepare the dataset interator for training and validation
         """
         self.batch_size = self.model_hparams_dict_load["batch_size"]
-        self.train_tf_dataset = self.train_dataset.make_dataset(self.batch_size)
-        self.train_iterator = self.train_tf_dataset.make_one_shot_iterator()
+        train_tf_dataset = self.train_dataset.make_dataset(self.batch_size)
+        train_iterator = train_tf_dataset.make_one_shot_iterator()
         # The `Iterator.string_handle()` method returns a tensor that can be evaluated
         # and used to feed the `handle` placeholder.
-        self.train_handle = self.train_iterator.string_handle()
-        self.val_tf_dataset = self.val_dataset.make_dataset(self.batch_size)
-        self.val_iterator = self.val_tf_dataset.make_one_shot_iterator()
-        self.val_handle = self.val_iterator.string_handle()
+        self.train_handle = train_iterator.string_handle()
+        val_tf_dataset = self.val_dataset.make_dataset(self.batch_size)
+        val_iterator = val_tf_dataset.make_one_shot_iterator()
+        self.val_handle = val_iterator.string_handle()
         self.iterator = tf.data.Iterator.from_string_handle(
-            self.train_handle, self.train_tf_dataset.output_types, self.train_tf_dataset.output_shapes)
+            self.train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
         self.inputs = self.iterator.get_next()
-        #since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train the model,
-        # otherwise the model will raise error
-        
+        # since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train SAVP
+        # Otherwise an error will be risen by SAVP 
         if self.dataset == "era5" and self.model == "savp":
-           del self.inputs["T_start"]
-
+            del self.inputs["T_start"]
 
     def save_dataset_model_params_to_checkpoint_dir(self, dataset, video_model):
         """
@@ -195,20 +198,18 @@ class TrainModel(object):
             f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4))
         with open(os.path.join(self.output_dir, "model_hparams.json"), "w") as f:
             f.write(json.dumps(video_model.hparams.values(), sort_keys=True, indent=4))
-        with open(os.path.join(self.output_dir, "data_dict.json"), "w") as f:
-            f.write(json.dumps(dataset.data_dict, sort_keys=True, indent=4))
-
+        #with open(os.path.join(self.output_dir, "data_dict.json"), "w") as f:
+        #   f.write(json.dumps(dataset.data_dict, sort_keys=True, indent=4))
 
     def count_parameters(self):
         """
         Count the paramteres of the model
         """ 
         with tf.name_scope("parameter_count"):
-            # exclude trainable variables that are replicas (used in multi-gpu setting)
+            # exclude trainable variables that are replicates (used in multi-gpu setting)
             self.trainable_variables = set(tf.trainable_variables()) & set(self.video_model.saveable_variables)
             self.parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in self.trainable_variables])
 
-
     def create_saver_and_writer(self):
         """
         Create saver to save the models latest checkpoints, and a summery writer to store the train/val metrics  
@@ -223,45 +224,49 @@ class TrainModel(object):
         self.gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.gpu_mem_frac, allow_growth=True)
         self.config = tf.ConfigProto(gpu_options=self.gpu_options, allow_soft_placement=True)
 
-
     def calculate_samples_and_epochs(self):
         """
         Calculate the number of samples for train dataset, which is used for each epoch training
         Calculate the iterations (samples multiple by max_epochs) for training.
         """
+        method = TrainModel.calculate_samples_and_epochs.__name__        
+
         batch_size = self.video_model.hparams.batch_size
-        max_epochs = self.video_model.hparams.max_epochs #the number of epochs
+        max_epochs = self.video_model.hparams.max_epochs # the number of epochs
         self.num_examples = self.train_dataset.num_examples_per_epoch()
         self.steps_per_epoch = int(self.num_examples/batch_size)
+        self.diag_intv_step = int(self.diag_intv_frac*self.steps_per_epoch)
         self.total_steps = self.steps_per_epoch * max_epochs
-        print("Batch size is {} ; max_epochs is {}; num_samples per epoch is {}; steps_per_epoch is {}, total steps is {}".format(batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps))
+        print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}"
+              .format(method, batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps))
 
     def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
         """
         Restore the models checkpoints if the checkpoints is given
         """
-  
+        method = TrainModel.restore.__name__
+
         if checkpoints is None:
-            print ("Checkpoint is empty!!")
-        elif os.path.isdir(checkpoints) and (not os.path.exists(os.path.join(checkpoints,"checkpoint"))):
-            print("There is not checkpoints in the dir {}".format(checkpoints))
+            print("%{0}: Checkpoint is None!".format(method))
+        elif os.path.isdir(checkpoints) and (not os.path.exists(os.path.join(checkpoints, "checkpoint"))):
+            print("%{0}: There are no checkpoints in the dir {1}".format(method, checkpoints))
         else:
-           var_list = self.video_model.saveable_variables
-           # possibly restore from multiple checkpoints. useful if subset of weights
-           # (e.g. generator or discriminator) are on different checkpoints.
-           if not isinstance(checkpoints, (list, tuple)):
-               checkpoints = [checkpoints]
-           # automatically skip global_step if more than one checkpoint is provided
-           skip_global_step = len(checkpoints) > 1
-           savers = []
-           for checkpoint in checkpoints:
-               print("creating restore saver from checkpoint %s" % checkpoint)
-               saver, _ = tf_utils.get_checkpoint_restore_saver(
-                   checkpoint, var_list, skip_global_step=skip_global_step,
-                   restore_to_checkpoint_mapping=restore_to_checkpoint_mapping)
-               savers.append(saver)
-           restore_op = [saver.saver_def.restore_op_name for saver in savers]
-           sess.run(restore_op)
+            var_list = self.video_model.saveable_variables
+            # possibly restore from multiple checkpoints. useful if subset of weights
+            # (e.g. generator or discriminator) are on different checkpoints.
+            if not isinstance(checkpoints, (list, tuple)):
+                checkpoints = [checkpoints]
+            # automatically skip global_step if more than one checkpoint is provided
+            skip_global_step = len(checkpoints) > 1
+            savers = []
+            for checkpoint in checkpoints:
+                print("%{0}: creating restore saver from checkpoint {1}".format(method, checkpoint))
+                saver, _ = tf_utils.get_checkpoint_restore_saver(checkpoint, var_list,
+                                                                 skip_global_step=skip_global_step,
+                                                                 restore_to_checkpoint_mapping=restore_to_checkpoint_mapping)
+                savers.append(saver)
+            restore_op = [saver.saver_def.restore_op_name for saver in savers]
+            sess.run(restore_op)
     
     def restore_train_val_losses(self):
         """
@@ -269,161 +274,216 @@ class TrainModel(object):
         """
         if self.checkpoint is None:
             train_losses, val_losses = [], []
-        elif os.path.isdir(self.checkpoint) and (not os.path.exists(os.path.join(self.output_dir,"checkpoint"))):
+        elif os.path.isdir(self.checkpoint) and (not os.path.exists(os.path.join(self.output_dir, "checkpoint"))):
             train_losses,val_losses = [], []
         else:
-            with open(os.path.join(self.output_dir,"train_losses.pkl"),"rb") as f:
+            with open(os.path.join(self.output_dir, "train_losses.pkl"), "rb") as f:
                 train_losses = pkl.load(f)
-            with open(os.path.join(self.output_dir,"val_losses.pkl"),"rb") as f:
+            with open(os.path.join(self.output_dir, "val_losses.pkl"), "rb") as f:
                 val_losses = pkl.load(f)
         return train_losses,val_losses
 
     def train_model(self):
         """
-        Start session and train the model
+        Start session and train the model by looping over all iteration steps
         """
+        method = TrainModel.train_model.__name__
+
         self.global_step = tf.train.get_or_create_global_step()
         with tf.Session(config=self.config) as sess:
             print("parameter_count =", sess.run(self.parameter_count))
             sess.run(tf.global_variables_initializer())
             sess.run(tf.local_variables_initializer())
             self.restore(sess, self.checkpoint)
-            #sess.graph.finalize()
-            self.start_step = sess.run(self.global_step)
-            print("start_step", self.start_step)
+            start_step = sess.run(self.global_step)
+            print("%{0}: Iteration starts at step {1}".format(method, start_step))
             # start at one step earlier to log everything without doing any training
             # step is relative to the start_step
             train_losses, val_losses = self.restore_train_val_losses()
+            # initialize auxiliary variables
             time_per_iteration = []
             run_start_time = time.time()
-            for step in range(self.start_step,self.total_steps):
+            val_loss_min = 999.
+            # perform iteration
+            for step in range(start_step, self.total_steps):
                 timeit_start = time.time()
-                #run for training dataset
+                # Run training data
                 self.create_fetches_for_train()             # In addition to the loss, we fetch the optimizer
                 self.results = sess.run(self.fetches)       # ...and run it here!
-                train_losses.append(self.results["total_loss"])
-                #Run and fetch losses for validation data
+                # Note: For SAVP, the obtained loss is a list where the first element is of interest, for convLSTM,
+                # it's just a number. Thus, with list(<losses>)[0], we can handle both
+                train_losses.append(list(self.results[self.saver_loss])[0])
+                # run and fetch losses for validation data
                 val_handle_eval = sess.run(self.val_handle)
                 self.create_fetches_for_val()
-                self.val_results = sess.run(self.val_fetches,feed_dict={self.train_handle: val_handle_eval})
-                val_losses.append(self.val_results["total_loss"])
+                self.val_results = sess.run(self.val_fetches, feed_dict={self.train_handle: val_handle_eval})
+                val_losses.append(list(self.val_results[self.saver_loss])[0])
                 self.write_to_summary()
-                self.print_results(step,self.results)
-                timeit_end = time.time()
-                time_per_iteration.append(timeit_end - timeit_start)
-                print("time needed for this step", timeit_end - timeit_start, ' s')
-                if step % self.save_model_intv == 0:
-                    self.saver.save(sess, os.path.join(self.output_dir, "model"), global_step=step)
-                if step % self.save_diag_intv == 0:
-                    # I save the pickle file and plot here inside the loop in case the training process cannot finished after job is done.
-                    TrainModel.save_results_to_pkl(train_losses,val_losses,self.output_dir)
-                    TrainModel.plot_train(train_losses,val_losses,step,self.output_dir)
-
-            #Totally train time over all the iterations
+                self.print_results(step, self.results)
+                # track iteration time
+                time_iter = time.time() - timeit_start
+                time_per_iteration.append(time_iter)
+                print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter))
+                if step > self.diag_intv_step and (step % self.diag_intv_step == 0 or step == self.total_steps - 1):
+                    lsave, val_loss_min = TrainModel.set_model_saver_flag(val_losses, val_loss_min, self.diag_intv_step)
+                    # save best and final model state
+                    if lsave or step == self.total_steps - 1:
+                        self.saver.save(sess, os.path.join(self.output_dir, "model_best" if lsave else "model_last"),
+                                        global_step=step)
+                    # pickle file and plots are always created
+                    TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
+                    TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
+
+            # Final diagnostics
+            # track time (save to pickle-files)
             train_time = time.time() - run_start_time
-            results_dict = {"train_time":train_time,
-                            "total_steps":self.total_steps}
+            results_dict = {"train_time": train_time,
+                            "total_steps": self.total_steps}
             TrainModel.save_results_to_dict(results_dict,self.output_dir)
-            print("train_losses:",train_losses)
-            print("val_losses:",val_losses) 
-            print("Done")
-            print("Total training time:", train_time/60., "min")
+            print("%{0}: Training loss decreased from {1:.6f} to {2:.6f}:"
+                  .format(method, np.mean(train_losses[0:10]), np.mean(train_losses[-self.diag_intv_step:])))
+            print("%{0}: Validation loss decreased from {1:.6f} to {2:.6f}:"
+                  .format(method, np.mean(val_losses[0:10]), np.mean(val_losses[-self.diag_intv_step:])))
+            print("%{0}: Training finsished".format(method))
+            print("%{0}: Total training time: {1:.2f} min".format(method, train_time/60.))
+
             return train_time, time_per_iteration
-            
  
     def create_fetches_for_train(self):
-       """
-       Fetch variables in the graph, this can be custermized based on models and also the needs of users
-       """
-       #This is the base fetch that for all the  models
-       self.fetches = {"train_op": self.video_model.train_op}         # fetching the optimizer!
-       self.fetches["summary"] = self.video_model.summary_op
-       self.fetches["global_step"] = self.global_step
-       if self.video_model.__class__.__name__ == "McNetVideoPredictionModel": self.fetches_for_train_mcnet()
-       if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": self.fetches_for_train_convLSTM()
-       if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.fetches_for_train_savp()
-       if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": self.fetches_for_train_vae()
-       if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":self.fetches_for_train_gan()
-       if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":self.fetches_for_train_convLSTM()
-       return self.fetches     
-    
-    def fetches_for_train_convLSTM(self):
         """
-        Fetch variables in the graph for convLSTM model, this can be custermized based on models and the needs of users
+        Fetch variables in the graph, this can be custermized based on models and also the needs of users
         """
-        self.fetches["total_loss"] = self.video_model.total_loss
-        self.fetches["inputs"] = self.video_model.inputs
+        # This is the basic fetch for all the models
+        fetch_list = ["train_op", "summary_op", "global_step"]
 
- 
-    def fetches_for_train_savp(self):
+        # Append fetches depending on model to be trained
+        if self.video_model.__class__.__name__ == "McNetVideoPredictionModel":
+            fetch_list = fetch_list + ["L_p", "L_gdl", "L_GAN"]
+            self.saver_loss = fetch_list[-3]  # ML: Is this a reasonable choice?
+            self.saver_loss_name = "Loss"
+        if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+            fetch_list = fetch_list + ["inputs", "total_loss"]
+            self.saver_loss = fetch_list[-1]
+            self.saver_loss_name = "Total loss"
+        if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
+            fetch_list = fetch_list + ["g_losses", "d_losses", "d_loss", "g_loss", ("g_losses", "gen_l1_loss")]
+            # Add loss that is tracked
+            self.saver_loss = fetch_list[-1][1]                
+            self.saver_loss_dict = fetch_list[-1][0]
+            self.saver_loss_name = "Generator L1 loss"
+        if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
+            fetch_list = fetch_list + ["latent_loss", "recon_loss", "total_loss"]
+            self.saver_loss = fetch_list[-2]
+            self.saver_loss_name = "Reconstruction loss"
+        if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":
+            fetch_list = fetch_list + ["inputs", "total_loss"]
+            self.saver_loss = fetch_list[-1]
+            self.saver_loss_name = "Total loss"
+        if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":
+            fetch_list = fetch_list + ["inputs", "total_loss"]
+            self.saver_loss = fetch_list[-1]
+            self.saver_loss_name = "Total loss"
+
+        self.fetches = self.generate_fetches(fetch_list)
+
+        return self.fetches
+
+    def create_fetches_for_val(self):
         """
-        Fetch variables in the graph for savp model, this can be custermized based on models and the needs of users
+        Fetch variables in the graph for validation dataset, customized depending on models and users' needs
         """
-        self.fetches["g_losses"] = self.video_model.g_losses
-        self.fetches["d_losses"] = self.video_model.d_losses
-        self.fetches["d_loss"] = self.video_model.d_loss
-        self.fetches["g_loss"] = self.video_model.g_loss
-        self.fetches["total_loss"] = self.video_model.g_loss
-        self.fetches["inputs"] = self.video_model.inputs
+        method = TrainModel.create_fetches_for_val.__name__
 
+        if not self.saver_loss:
+            raise AttributeError("%{0}: saver_loss is still not set. create_fetches_for_train must be run in advance."
+                                 .format(method))
+        
+        if self.saver_loss_dict:
+            fetch_list = ["summary_op", (self.saver_loss_dict, self.saver_loss)]
+        else:
+            fetch_list= ["summary_op", self.saver_loss]
 
-    def fetches_for_train_mcnet(self):
-        """
-        Fetch variables in the graph for mcnet model, this can be custermized based on models and  the needs of users
-        """
-        self.fetches["L_p"] = self.video_model.L_p
-        self.fetches["L_gdl"] = self.video_model.L_gdl
-        self.fetches["L_GAN"]  = self.video_model.L_GAN        
+        self.val_fetches = self.generate_fetches(fetch_list)
 
-    def fetches_for_train_vae(self):
+        return self.val_fetches
+
+    def generate_fetches(self, fetch_list):
         """
-        Fetch variables in the graph for savp model, this can be custermized based on models and based on the needs of users
+        Generates dictionary of fetches from video model instance
+        :param fetch_list: list of attributes of video model instance that are of particular interest; 
+                           can also handle tuples as list-elements to get attributes nested in a dictionary
+        :return: dictionary of fetches with keys from fetch_list and values from video model instance
         """
-        self.fetches["latent_loss"] = self.video_model.latent_loss
-        self.fetches["recon_loss"] = self.video_model.recon_loss
-        self.fetches["total_loss"] = self.video_model.total_loss
+        method = TrainModel.generate_fetches.__name__
 
-    def fetches_for_train_gan(self):
-        self.fetches["total_loss"] = self.video_model.total_loss
+        if not self.video_model:
+            raise AttributeError("%{0}: video_model is still not set. setup_model must be run in advance."
+                                 .format(method))
 
-    def create_fetches_for_val(self):
-        """
-        Fetch variables in the graph for validation dataset, this can be custermized based on models and the needs of users
-        """
-        if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
-            self.val_fetches = {"total_loss": self.video_model.g_loss}
-            self.val_fetches["inputs"] = self.video_model.inputs
-        else:
-            self.val_fetches = {"total_loss": self.video_model.total_loss}
-            self.val_fetches["inputs"] = self.video_model.inputs
-        self.val_fetches["summary"] = self.video_model.summary_op
+        fetches = {}
+        for fetch_req in fetch_list:
+            try:
+                if isinstance(fetch_req, tuple):
+                    fetches[fetch_req[1]] = getattr(self.video_model, fetch_req[0])[fetch_req[1]]
+                else:
+                    fetches[fetch_req] = getattr(self.video_model, fetch_req)
+            except Exception as err:
+                print("%{0}: Failed to retrieve {1} from video_model-attribute.".format(method, fetch_req))
+                raise err
+
+        return fetches
 
     def write_to_summary(self):
-        self.summary_writer.add_summary(self.results["summary"],self.results["global_step"])
-        self.summary_writer.add_summary(self.val_results["summary"],self.results["global_step"])
+        self.summary_writer.add_summary(self.results["summary_op"], self.results["global_step"])
+        self.summary_writer.add_summary(self.val_results["summary_op"], self.results["global_step"])
         self.summary_writer.flush()
 
-
-    def print_results(self,step,results):
+    def print_results(self, step, results):
         """
         Print the training results /validation results from the training step.
         """
+        method = TrainModel.print_results.__name__
+
         train_epoch = step/self.steps_per_epoch
-        print("progress  global step %d  epoch %0.1f" % (step + 1, train_epoch))
+        print("%{0}: Progress global step {1:d}  epoch {2:.1f}".format(method, step + 1, train_epoch))
         if self.video_model.__class__.__name__ == "McNetVideoPredictionModel":
-            print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
+            print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"], results["L_p"],
+                                                                           results["L_gdl"],results["L_GAN"]))
         elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
             print ("Total_loss:{}".format(results["total_loss"]))
         elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
-            print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"]))
+            print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}, gen_l1_loss: {}"
+                  .format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"],results["gen_l1_loss"]))
         elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
             print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"]))
         else:
-            print ("The model name does not exist")
+            print("%{0}: Printing results of the model {1} is not implemented yet".format(method, self.video_model.__class__.__name__))
     
-            
     @staticmethod
-    def plot_train(train_losses,val_losses,step,output_dir):
+    def set_model_saver_flag(losses: List, old_min_loss: float, niter_steps: int = 100):
+        """
+        Sets flag to save the model given that a new minimum in the loss is readched
+        :param losses: list of losses over iteration steps
+        :param old_min_loss: previous loss
+        :param niter_steps: number of iteration steps over which the loss is averaged
+        :return flag: True if model should be saved
+        :return loss_avg: updated minimum loss
+        """
+        save_flag = False
+        if len(losses) <= niter_steps*2:
+            loss_avg = old_min_loss
+            return save_flag, loss_avg
+
+        loss_avg = np.mean(losses[-niter_steps:])
+        if loss_avg < old_min_loss:
+            save_flag = True
+        else:
+            loss_avg = old_min_loss
+
+        return save_flag, loss_avg
+
+    @staticmethod
+    def plot_train(train_losses, val_losses, loss_name, output_dir):
         """
         Function to plot training losses for train and val datasets against steps
         params:
@@ -441,7 +501,7 @@ class TrainModel(object):
         plt.yscale("log")
         plt.title('Training and Validation loss')
         plt.xlabel('Iterations')
-        plt.ylabel('Loss')
+        plt.ylabel(loss_name)
         plt.legend()
         plt.savefig(os.path.join(output_dir,'plot_train.png'))
         plt.close()
@@ -477,19 +537,17 @@ class TrainModel(object):
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
-                                                                     "train, val, test, etc, or a directory containing "
-                                                                     "the tfrecords")
-    parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
-                                             "default is logs_dir/model_fname, where model_fname consists of "
-                                             "information from model and model_hparams")
-    parser.add_argument("--datasplit_dict", help="json file that contains the datasplit configuration")
-    parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
-    parser.add_argument("--dataset", type=str, help="dataset class name")
-    parser.add_argument("--model", type=str, help="model class name")
-    parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
-    parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="fraction of gpu memory to use")
-    parser.add_argument("--seed",default=1234, type=int)
+    parser.add_argument("--input_dir", type=str, required=True,
+                        help="Directory where input data as TFRecord-files are stored.")
+    parser.add_argument("--output_dir", help="Output directory where JSON-files, summary, model, plots etc. are saved.")
+    parser.add_argument("--datasplit_dict", help="JSON-file that contains the datasplit configuration")
+    parser.add_argument("--checkpoint", help="Checkpoint directory or checkpoint name (e.g. <my_dir>/model-200000)")
+    parser.add_argument("--dataset", type=str, help="Dataset class name")
+    parser.add_argument("--model", type=str, help="Model class name")
+    parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters")
+    parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use")
+    parser.add_argument("--seed", default=1234, type=int)
+
     args = parser.parse_args()
     # start timing for the whole run
     timeit_start_total_time = time.time()  
diff --git a/video_prediction_tools/model_modules/video_prediction/models/base_model.py b/video_prediction_tools/model_modules/video_prediction/models/base_model.py
index 508f766d428d402c5119afbfd811f1eb73cbb5f7..41f4bb3d1eb59c7e51624c03bfb8a582cb38b1e2 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/base_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/base_model.py
@@ -69,6 +69,8 @@ class BaseVideoPredictionModel(object):
         self.accum_eval_metrics = None
         self.saveable_variables = None
         self.post_init_ops = None
+        # ML 2021-06-23: Do not hide global step in self.saveable_variables
+        self.global_step = None
 
     def get_default_hparams_dict(self):
         """
@@ -471,6 +473,8 @@ class VideoPredictionModel(BaseVideoPredictionModel):
         BaseVideoPredictionModel.build_graph(self, inputs)
 
         global_step = tf.train.get_or_create_global_step()
+        # ML 2021-06-23: Do not hide global step in self.saveable_variables
+        self.global_step = global_step
         # Capture the variables created from here until the train_op for the
         # saveable_variables. Note that if variables are being reused (e.g.
         # they were created by a previously built model), those variables won't
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
index 0c079c4c087d919490b1ade80f0bd73368f008c7..5c949c96a4f849d82143e80878bb59d8f6661965 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
@@ -15,7 +15,6 @@ class VanillaConvLstmVideoPredictionModel(object):
         """
         This is class for building convLSTM architecture by using updated hparameters
         args:
-             mode          :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model
              hparams_dict : dict, the dictionary contains the hparaemters names and values
         """
         self.hparams_dict = hparams_dict
diff --git a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
index 7a1da880defb61dbd018c6f11ee14c34cf0ce43e..415275e8f909560cab725ecae38bb37a01809244 100644
--- a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
+++ b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
@@ -526,10 +526,14 @@ def reduce_tensors(structures, shallow=False):
 
 def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None):
 
+    method = get_checkpoint_restore_saver.__name__
 
     if os.path.isdir(checkpoint):
         # latest_checkpoint doesn't work when the path has special characters
         checkpoint = tf.train.latest_checkpoint(checkpoint)
+    # print name of checkpoint-file for verbosity
+    print("%{0}: The follwoing checkpoint is used for restoring the model: '{1}'".format(method, checkpoint))
+    # Start processing the checkpoint
     checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint)
     checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys()
     restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0])
diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py
index 65a25c43e803abc4240edd2e6c4184420a5b0722..83ed526bac9c960229a30585be61606cbf527317 100644
--- a/video_prediction_tools/utils/general_utils.py
+++ b/video_prediction_tools/utils/general_utils.py
@@ -11,12 +11,15 @@ Provides:   * get_unique_vars
 """
 
 # import modules
+from typing import List, Union
 import os
 import numpy as np
-#import xarray as xr
 
+str_or_List = Union[List, str]
 # routines
-def get_unique_vars(varnames):
+
+
+def get_unique_vars(varnames: List[str]):
     """
     :param varnames: list of variable names (or any other list of strings)
     :return: list with unique elements of inputted varnames list
@@ -27,7 +30,7 @@ def get_unique_vars(varnames):
     return vars_uni, varsind, nvars_uni
 
 
-def add_str_to_path(path_in, add_str):
+def add_str_to_path(path_in: str, add_str: str):
     """
     :param path_in: input path which is a candidate to be extended by add_str (see below)
     :param add_str: String to be added to path_in if not already done
@@ -92,12 +95,15 @@ def isw(value, interval):
     return status
 
 
-def check_str_in_list(list_in, str2check, labort=True):
+def check_str_in_list(list_in: List, str2check: str_or_List, labort: bool = True, return_ind: bool = False):
     """
     Checks if all strings are found in list
     :param list_in: input list
     :param str2check: string or list of strings to be checked if they are part of list_in
-    :return: True if existence of all strings was confirmed
+    :param labort: Flag if error will be risen in case of missing string in list
+    :param return_ind: Flag if index for each string found in list will be returned
+    :return: True if existence of all strings was confirmed, if return_ind is True, the index of each string in list is
+             returned as well
     """
     method = check_str_in_list.__name__
 
@@ -112,20 +118,25 @@ def check_str_in_list(list_in, str2check, labort=True):
 
     stat_element = [True if str1 in list_in else False for str1 in str2check]
 
-    if not np.all(stat_element):
+    if np.all(stat_element):
+        stat = True
+    else:
         print("%{0}: The following elements are not part of the input list:".format(method))
-        inds_miss = np.where(stat_element)[0]
+        inds_miss = np.where(list(~np.array(stat_element)))[0]
         for i in inds_miss:
             print("* index {0:d}: {1}".format(i, str2check[i]))
         if labort:
             raise ValueError("%{0}: Could not find all expected strings in list.".format(method))
+    # return
+    if stat and not return_ind:
+        return stat
+    elif stat:
+        return stat, [list_in.index(str_curr) for str_curr in str2check]
     else:
-        stat = True
-    
-    return stat
+        return stat, []
 
 
-def check_dir(path2dir: str, lcreate=False):
+def check_dir(path2dir: str, lcreate: bool = False):
     """
     Checks if path2dir exists and create it if desired
     :param path2dir:
@@ -174,7 +185,31 @@ def reduce_dict(dict_in: dict, dict_ref: dict):
     return dict_reduced
 
 
-def provide_default(dict_in, keyname, default=None, required=False):
+def find_key(dict_in: dict, key: str):
+    """
+    Searchs through nested dictionaries for key.
+    :param dict_in: input dictionary (cas also be an OrderedDictionary)
+    :param key: key to be retrieved
+    :return: value of the key in dict_in
+    """
+    method = find_key.__name__
+    # sanity check
+    if not isinstance(dict_in, dict):
+        raise TypeError("%{0}: dict_in must be a dictionary instance, but is of type '{1}'"
+                        .format(method, type(dict_in)))
+    # check for key
+    if key in dict_in:
+        return dict_in[key]
+    for k, v in dict_in.items():
+        if isinstance(v,dict):
+            item = find_key(v, key)
+            if item is not None:
+                return item
+
+    raise ValueError("%{0}: {1} could not be found in dict_in".format(method, key))
+
+
+def provide_default(dict_in: dict, keyname: str, default=None, required: bool = False):
     """
     Returns values of key from input dictionary or alternatively its default
 
diff --git a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py
index d420ed1be8460ade5fae3302960de8b761f49f72..025a203b5bd7bc038239a95afea003938b811734 100755
--- a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py
+++ b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py
@@ -280,11 +280,11 @@ class Config_Preprocess1(Config_runscript_base):
         check_vars = [var.strip().lower() in known_vars for var in vars_list]
         status = all(check_vars)
         if not status:
-            inds_bad = [i for i, e in enumerate(check_vars) if e]  # np.where(~np.array(check_vars))[0]
+            inds_bad = [i for i, e in enumerate(check_vars) if not e]  # np.where(~np.array(check_vars))[0]
             if not silent:
                 print("%{0}: The following comma-separated elements are unknown variables:".format(method))
                 for ind in inds_bad:
-                    print(vars_list[ind])
+                    print("* {0}".format(vars_list[ind]))
             return status
 
         if not len(check_vars) >= 1: