diff --git a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh b/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh
index a6da4643636ab7997a72cfc81f9311de6d7e8527..80d4de5266bc57c944bd57ffa5359512b4f23a4b 100644
--- a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh
+++ b/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh
@@ -23,9 +23,9 @@ VIRT_ENV_NAME="my_venv"
 
 # Activate virtual environment if needed (and possible)
 if [ -z ${VIRTUAL_ENV} ]; then
-   if [[ -f ../${VIRT_ENV_NAME}/bin/activate ]]; then
+   if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then
       echo "Activating virtual environment..."
-      source ../${VIRT_ENV_NAME}/bin/activate
+      source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate
    else 
       echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
       exit 1
diff --git a/video_prediction_tools/data_preprocess/preprocess_data_step2.py b/video_prediction_tools/data_preprocess/preprocess_data_step2.py
index a062ee1cefc0b0683f59a4d86736a4500243761e..a197471b22f28cc1c3bae9fe29bd7279d2015cde 100644
--- a/video_prediction_tools/data_preprocess/preprocess_data_step2.py
+++ b/video_prediction_tools/data_preprocess/preprocess_data_step2.py
@@ -218,7 +218,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
             X_end = X_start + self.sequence_length
             seq = X_train[X_start:X_end, ...]
             # recording the start point of the timestamps (already datetime-objects)
-            t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start][0])
+           
+            t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start])
             seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars)))
             if not sequences:
                 last_start_sequence_iter = sequence_iter
diff --git a/video_prediction_tools/env_setup/install_venv_container.sh b/video_prediction_tools/env_setup/install_venv_container.sh
index f0d53cdb495d5df6dffd30b79a19f53e1a0b2e98..3e5c35b9c4d635179fafab47ee65e153b80d2380 100755
--- a/video_prediction_tools/env_setup/install_venv_container.sh
+++ b/video_prediction_tools/env_setup/install_venv_container.sh
@@ -67,13 +67,15 @@ echo "Actiavting virtual environment ${VENV_NAME} to install required Python mod
 ACT_VENV="${VENV_DIR}/bin/activate"
 source "${VENV_DIR}/bin/activate"
 # set PYTHONPATH...
-export PYTHONPATH="/usr/local/lib/python3.8/dist-packages/"
+export PYTHONPATH=/usr/local/lib/python3.8/dist-packages/:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}/virtual_envs/${VENV_NAME}/lib/python3.8/site-packages:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/utils:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/model_modules:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/postprocess:$PYTHONPATH
 # ... also ensure that PYTHONPATH is appended when activating the virtual environment...
-echo 'export PYTHONPATH="/usr/local/lib/python3.8/dist-packages/"' >> "${ACT_VENV}"
+echo 'export PYTHONPATH=/usr/local/lib/python3.8/dist-packages/:$PYTHONPATH' >> "${ACT_VENV}"
+echo 'export PYTHONPATH='${WORKING_DIR}'/virtual_envs/'${VENV_NAME}'/lib/python3.8/site-packages:$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}':$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}'/utils:$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}'/model_modules:$PYTHONPATH' >> ${ACT_VENV}
diff --git a/video_prediction_tools/env_setup/modules_preprocess+extract.sh b/video_prediction_tools/env_setup/modules_preprocess+extract.sh
index c867554716e49f9fbe5c66275a158fefd505f927..7976201ab97cdc14b9ab3418e86898defc48fdf7 100755
--- a/video_prediction_tools/env_setup/modules_preprocess+extract.sh
+++ b/video_prediction_tools/env_setup/modules_preprocess+extract.sh
@@ -1,33 +1,32 @@
 #!/usr/bin/env bash
 
-# __author__ = Bing Gong, Michael Langguth
-# __date__  = '2020_06_26'
+# __author__ = Michael Langguth
+# __date__  = '2022_02_07'
 
-# This script loads the required modules for ambs on Juwels and HDF-ML.
-# Note that some other packages have to be installed into a venv (see create_env.sh and requirements.txt).
+# This script loads the required modules for AMBS on JSC's HPC systems (HDF-ML, Juwels Cluster and Juwels Booster).
+# Further Python-packages may be installed in the virtual environment created by create_env.sh
+# (see also requirements.txt).
 
-HOST_NAME=`hostname`
+HOST_NAME=$(hostname)
 
 echo "Start loading modules on ${HOST_NAME} required for preprocessing..."
-echo "modules_preprocess.sh is subject to: "
+echo "modules_preprocess+extract.sh is used for: "
+echo "* data_extraction_era5.sh"
 echo "* preprocess_data_era5_step1.sh"
+echo "* generate_runscript.py"
 
 module purge
-module use $OTHERSTAGES
-ml Stages/2019a
-ml GCC/8.3.0
-ml ParaStationMPI/5.2.2-1
-ml mpi4py/3.0.1-Python-3.6.8
-# serialized version is not available on HFML
-# see https://gitlab.version.fz-juelich.de/haf/Wiki/-/wikis/HDF-ML%20System
-if [[ "${HOST_NAME}" == hdfml* ]]; then
-  ml h5py/2.9.0-serial-Python-3.6.8
-elif [[ "${HOST_NAME}" == juwels* ]]; then
-  ml h5py/2.9.0-Python-3.6.8
-fi
-ml SciPy-Stack/2019a-Python-3.6.8
-ml scikit/2019a-Python-3.6.8
-ml netcdf4-python/1.5.0.1-Python-3.6.8
+module use "$OTHERSTAGES"
+ml Stages/2020
+ml GCC/10.3.0
+ml GCCcore/.10.3.0
+ml ParaStationMPI/5.4.10-1
+ml mpi4py/3.0.3-Python-3.8.5
+ml h5py/2.10.0-Python-3.8.5
+ml netcdf4-python/1.5.4-Python-3.8.5
+ml SciPy-Stack/2021-Python-3.8.5
+ml scikit/2021-Python-3.8.5
+ml CDO/2.0.0rc3
 
 # clean up if triggered via script argument
 if [[ $1 == purge ]]; then
diff --git a/video_prediction_tools/env_setup/requirements.txt b/video_prediction_tools/env_setup/requirements.txt
index 9f433734a966541c6c6a20a6387a499716b2d80a..28b7c6f83865095745ccab685b08c60aba8a71f9 100755
--- a/video_prediction_tools/env_setup/requirements.txt
+++ b/video_prediction_tools/env_setup/requirements.txt
@@ -3,5 +3,7 @@ mpi4py==3.0.1
 pandas==0.25.3
 xarray==0.16.0
 basemap==1.3.0
+numpy==1.17.3     # although this numpy-version is in the container, we set it here to avoid any further installation
 scikit-image==0.18.1
 opencv-python-headless==4.2.0.34
+netcdf4
diff --git a/video_prediction_tools/env_setup/wrapper_container.sh b/video_prediction_tools/env_setup/wrapper_container.sh
index fea29a0a9018a5436122389164cfff0859f22552..cfe716bee9f610b4a44988fc2ff6e4be048d06b4 100755
--- a/video_prediction_tools/env_setup/wrapper_container.sh
+++ b/video_prediction_tools/env_setup/wrapper_container.sh
@@ -27,6 +27,8 @@ export PYTHONPATH=/usr/local/lib/python3.8/dist-packages:$PYTHONPATH
 # ... and modules from this project
 export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/utils:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}/model_modules:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}/postprocess:$PYTHONPATH
 
 # Control
 echo "****** Check PYTHONPATH *****"
diff --git a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json
index 2275e60f543badb1367351a50938e7bcacf2f119..f36e1c0b44279ad2e4f9e741c7bfade0a5aa0a05 100644
--- a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json
+++ b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json
@@ -1,5 +1,5 @@
 {
-    "batch_size": 4,
+    "batch_size": 32,
     "lr": 0.0002,
     "beta1": 0.5,
     "beta2": 0.999,
@@ -12,9 +12,11 @@
     "gan_feature_cdist_weight": 0.0,
     "state_weight": 0.0,
     "nz": 16,
-    "max_epochs":2,
+    "max_epochs":4,
     "context_frames": 12,
-    "opt_var": "0"
+    "opt_var": "0",
+    "decay_steps":[3000,9000],
+    "end_lr": 0.00000008
 }
 
 
diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index 7ccddc88c66128bcab07104a818a9ff73faa3316..9e58de96a31913eb19678e151fac5c46d6e80409 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -11,14 +11,15 @@ __email__ = "b.gong@fz-juelich.de"
 __author__ = "Bing Gong, Michael Langguth"
 __date__ = "2020-10-22"
 
+import os, glob
 import argparse
 import errno
 import json
-import os
 from typing import Union, List
 import random
 import time
 import numpy as np
+import xarray as xr
 import tensorflow as tf
 from model_modules.video_prediction import datasets, models
 import matplotlib.pyplot as plt
@@ -26,12 +27,13 @@ import pickle as pkl
 from model_modules.video_prediction.utils import tf_utils
 from general_utils import *
 import math
-
+import shutil
 
 class TrainModel(object):
     def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None,
                  model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None,
-                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float=None, prob_save_model:float=None):
+                 gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.001,
+                 frac_start_save: float = None, frac_intv_save: float = None):
         """
         Class instance for training the models
         :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located
@@ -44,12 +46,9 @@ class TrainModel(object):
         :param gpu_mem_frac: fraction of GPU memory to be preallocated
         :param seed: seed of the randomizers
         :param args: list of arguments passed
-        :param diag_intv_frac: interval for diagnozing and saving model; the fraction with respect to the number of
-                               steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results
-                               into a diagnozing intreval of 10 iteration steps (= interval over which validation loss
-                               is averaged to identify best model performance)
-        :param frac_save_model_start: fraction of total iterations steps as the start point to save checkpoints
-        :param prob_save_model: probabability that model are saved to checkpoint (control the frequences of saving model0)
+        :param diag_intv_frac: fraction of total iteration steps defining the interval for diagnosing the model (create loss-curves and save pickle-file with losses)
+        :param frac_start_save: fraction of total iteration steps after which checkpointing of the model starts
+        :param frac_intv_save: fraction of total iteration steps defining the interval for checkpointing the model
         """
         self.input_dir = os.path.normpath(input_dir)
         self.output_dir = os.path.normpath(output_dir)
@@ -62,8 +61,8 @@ class TrainModel(object):
         self.seed = seed
         self.args = args
         self.diag_intv_frac = diag_intv_frac
-        self.frac_save_model_start = frac_save_model_start
-        self.prob_save_model = prob_save_model
+        self.frac_start_save = frac_start_save
+        self.frac_intv_save = frac_intv_save
         # for diagnozing and saving the model during training
         self.saver_loss = None         # set in create_fetches_for_train-method
         self.saver_loss_name = None    # set in create_fetches_for_train-method 
@@ -74,18 +73,16 @@ class TrainModel(object):
         self.set_seed()
         self.get_model_hparams_dict()
         self.load_params_from_checkpoints_dir()
-        self.setup_dataset()
-        self.setup_model()
+        self.setup_datasets()
         self.make_dataset_iterator()
+        self.setup_model()
         self.setup_graph()
         self.save_dataset_model_params_to_checkpoint_dir(dataset=self.train_dataset,video_model=self.video_model)
         self.count_parameters()
         self.create_saver_and_writer()
         self.setup_gpu_config()
-        self.calculate_samples_and_epochs()
         self.calculate_checkpoint_saver_conf()
 
-
     def set_seed(self):
         """
         Set seed to control the same train/val/testing dataset for the same seed
@@ -148,16 +145,23 @@ class TrainModel(object):
             except FileNotFoundError:
                 print("%{0}: model_hparams.json does not exist in {1}".format(method, self.checkpoint_dir))
                 
-    def setup_dataset(self):
+    def setup_datasets(self):
         """
         Setup train and val dataset instance with the corresponding data split configuration.
         Simultaneously, sequence_length is attached to the hyperparameter dictionary.
         """
+        # get some parameters from the model hyperparameters
+        self.batch_size = self.model_hparams_dict_load["batch_size"]
+        self.max_epochs = self.model_hparams_dict_load["max_epochs"]
+        # create dataset instance
         VideoDataset = datasets.get_dataset_class(self.dataset)
         self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict,
                                           hparams_dict_config=self.model_hparams_dict)
+        self.calculate_samples_and_epochs()
+        self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
+        # set-up validation dataset and calculate number of batches for calculating validation loss
         self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict,
-                                        hparams_dict_config=self.model_hparams_dict)
+                                        hparams_dict_config=self.model_hparams_dict, nsamples_ref=self.num_examples)
         # Retrieve sequence length from dataset
         self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
 
@@ -239,15 +243,17 @@ class TrainModel(object):
         """
         method = TrainModel.calculate_samples_and_epochs.__name__        
 
-        batch_size = self.video_model.hparams.batch_size
-        max_epochs = self.video_model.hparams.max_epochs # the number of epochs
         self.num_examples = self.train_dataset.num_examples_per_epoch()
-        self.steps_per_epoch = int(self.num_examples/batch_size)
-        self.diag_intv_step = int(self.diag_intv_frac*self.steps_per_epoch)
-        self.total_steps = self.steps_per_epoch * max_epochs
+        self.steps_per_epoch = int(self.num_examples/self.batch_size)
+        self.total_steps = self.steps_per_epoch * self.max_epochs
+        self.diag_intv_step = int(self.diag_intv_frac*self.total_steps)
+        if self.diag_intv_step == 0:
+            self.diag_intv_step = 1
+        else:
+            pass
         print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}"
-              .format(method, batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps))
-
+              .format(method, self.batch_size, self.max_epochs, self.num_examples, self.steps_per_epoch,
+                      self.total_steps))
 
     def calculate_checkpoint_saver_conf(self):
         """
@@ -256,17 +262,18 @@ class TrainModel(object):
         """
         method = TrainModel.calculate_checkpoint_saver_conf.__name__
 
-        if hasattr(self.total_steps, "attr_name"):
-            raise SyntaxError(" function 'calculate_sample_and_epochs' is required to call to calcualte the total_step before all function {}".format(method))
-        if self.prob_save_model > 1 or self.prob_save_model<0 :
-            raise ValueError("pro_save_model should be less than 1 and larger than 0")
-        if self.frac_save_model_start > 1 or self.frac_save_model_start<0:
-            raise ValueError("frac_save_model_start should be less than 1 and larger than 0")
-
-        self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start))
-        self.saver_interval_step = int(math.ceil(self.total_steps * self.prob_save_model))
-        print("The model will be saved starting from step {} with {} interval step ".format(str(self.start_checkpoint_step),self.saver_interval_step))
+        if not hasattr(self, "total_steps"):
+            raise RuntimeError("%{0} self.total_steps is still unset. Run calculate_samples_and_epochs beforehand"
+                               .format(method))
+        if self.frac_intv_save > 1 or self.frac_intv_save<0 :
+            raise ValueError("%{0}: frac_intv_save must be less than 1 and larger than 0".format(method))
+        if self.frac_start_save > 1 or self.frac_start_save < 0:
+            raise ValueError("%{0}: frac_start_save must be less than 1 and larger than 0".format(method))
 
+        self.chp_start_step = int(math.ceil(self.total_steps * self.frac_start_save))
+        self.chp_intv_step = int(math.ceil(self.total_steps * self.frac_intv_save))
+        print("%{0}: Model will be saved after step {1:d} at each {2:d} interval step "
+              .format(method, self.chp_start_step,self.chp_intv_step))
 
     def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
         """
@@ -313,20 +320,14 @@ class TrainModel(object):
 
     def create_checkpoints_folder(self, step:int=None):
         """
-        Create a folder to store checkpoint at certain step
-        :param step: the step you want to save the checkpoint
+        Create a folder to store checkpoint at certain step.
+        :param step: the iteration step corresponding to the checkpoint
         return : dir path to save model
         """
-        dir_name = "checkpoint_" + str(step)
-        full_dir_name = os.path.join(self.output_dir,dir_name)
-        if os.path.isfile(os.path.join(full_dir_name,"checkpoints")):
-            print("The checkpoint at step {} exists".format(step))
-        else:
-            os.mkdir(full_dir_name)
+        full_dir_name = os.path.join(self.output_dir, "checkpoint_{0:d}".format(step))
+        os.makedirs(full_dir_name, exist_ok=True)
         return full_dir_name
 
-
-
     def train_model(self):
         """
         Start session and train the model by looping over all iteration steps
@@ -335,7 +336,6 @@ class TrainModel(object):
 
         self.global_step = tf.train.get_or_create_global_step()
         with tf.Session(config=self.config) as sess:
-            print("parameter_count =", sess.run(self.parameter_count))
             sess.run(tf.global_variables_initializer())
             sess.run(tf.local_variables_initializer())
             self.restore(sess, self.checkpoint)
@@ -347,7 +347,6 @@ class TrainModel(object):
             # initialize auxiliary variables
             time_per_iteration = []
             run_start_time = time.time()
-
             # perform iteration
             for step in range(start_step, self.total_steps):
                 timeit_start = time.time()
@@ -368,22 +367,22 @@ class TrainModel(object):
                 time_iter = time.time() - timeit_start
                 time_per_iteration.append(time_iter)
                 print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter))
-
-                if step > self.start_checkpoint_step and (step % self.saver_interval_step == 0 or step == self.total_steps - 1):
+                if (step >= self.chp_start_step and (step-self.chp_start_step)%self.chp_intv_step == 0) or \
+                    step == self.total_steps - 1:
                     #create a checkpoint folder for step
                     full_dir_name = self.create_checkpoints_folder(step=step)
                     self.saver.save(sess, os.path.join(full_dir_name, "model_"), global_step=step)
 
                 # pickle file and plots are always created
-                TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
-                TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
+                if step % self.diag_intv_step == 0 or step == self.total_steps - 1:
+                    TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
+                    TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
 
-            # Final diagnostics
-            # track time (save to pickle-files)
+            # Final diagnostics: track training time (and save to pickle-files)
             train_time = time.time() - run_start_time
-            results_dict = {"train_time": train_time,
-                            "total_steps": self.total_steps}
-            TrainModel.save_results_to_dict(results_dict,self.output_dir)
+            results_dict = {"train_time": train_time, "total_steps": self.total_steps}
+            TrainModel.save_results_to_dict(results_dict, self.output_dir)
+
             print("%{0}: Training loss decreased from {1:.6f} to {2:.6f}:"
                   .format(method, np.mean(train_losses[0:10]), np.mean(train_losses[-self.diag_intv_step:])))
             print("%{0}: Validation loss decreased from {1:.6f} to {2:.6f}:"
@@ -441,7 +440,6 @@ class TrainModel(object):
         if not self.saver_loss:
             raise AttributeError("%{0}: saver_loss is still not set. create_fetches_for_train must be run in advance."
                                  .format(method))
-        
         if self.saver_loss_dict:
             fetch_list = ["summary_op", (self.saver_loss_dict, self.saver_loss)]
         else:
@@ -497,39 +495,13 @@ class TrainModel(object):
             print ("Total_loss:{}".format(results["total_loss"]))
         elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
             print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}, gen_l1_loss: {}"
-                  .format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"],results["gen_l1_loss"]))
+                  .format(results["g_losses"], results["d_losses"], results["g_loss"], results["d_loss"],
+                          results["gen_l1_loss"]))
         elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
-            print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"]))
+            print("Total_loss:{}; latent_losses:{}; reconst_loss:{}"
+                  .format(results["total_loss"], results["latent_loss"], results["recon_loss"]))
         else:
-            print("%{0}: Printing results of the model {1} is not implemented yet".format(method, self.video_model.__class__.__name__))
-    
-    @staticmethod
-    def set_model_saver_flag(losses: List, old_min_loss: float, niter_steps: int = 100):
-        """
-        Sets flag to save the model given that a new minimum in the loss is readched
-        :param losses: list of losses over iteration steps
-        :param old_min_loss: previous loss
-        :param niter_steps: number of iteration steps over which the loss is averaged
-        :return flag: True if model should be saved
-        :return loss_avg: updated minimum loss
-        """
-        method = TrainModel.set_model_saver_flag.__name__       
- 
-        save_flag = False
-        if len(losses) <= niter_steps*2:
-            loss_avg = old_min_loss
-            return save_flag, loss_avg
-
-        loss_avg = np.mean(losses[-niter_steps:])
-        # print diagnosis
-        print("%{0}: Current loss: {1:.4f}, old minimum: {2:.4f}, model will be saved: {3}"
-              .format(method, loss_avg, old_min_loss, loss_avg < old_min_loss))
-        if loss_avg < old_min_loss:
-            save_flag = True
-        else:
-            loss_avg = old_min_loss
-
-        return save_flag, loss_avg
+            print("%{0}: Printing results of model '{1}' is not implemented yet".format(method, self.video_model.__class__.__name__))
 
     @staticmethod
     def plot_train(train_losses, val_losses, loss_name, output_dir):
@@ -584,6 +556,162 @@ class TrainModel(object):
             pkl.dump(loss_per_iteration_val,f)
 
 
+
+class BestModelSelector(object):
+    """
+    Class to select the best performing model from multiple checkpoints created during training
+    """
+        
+    def __init__(self, model_dir: str, eval_metric: str, criterion: str = "min", channel: int = 0, seed: int = 42):
+        """
+        Class to retrieve the best model checkpoint. The last one is also retained.
+        :param model_dir: path to directory where checkpoints are saved (the trained model output directory)
+        :param eval_metric: evaluation metric for model selection (must be implemented in Scores)
+        :param criterion: set to 'min' ('max') for negatively (positively) oriented metrics
+        :param channel: channel of data used for selection
+        :param seed: seed for the Postprocess-instance
+        """
+        method = self.__class__.__name__
+        # sanity check
+        if not os.path.isdir(model_dir):
+            raise NotADirectoryError("{0}: The passed directory '{1}' does not exist".format(method, model_dir))
+        assert criterion in ["min", "max"], "%{0}: criterion must be either 'min' or 'max'.".format(method)
+        # set class attributes
+        self.seed = seed
+        self.channel = channel
+        self.metric = eval_metric
+        self.checkpoint_base_dir = model_dir
+        self.checkpoints_all = BestModelSelector.get_checkpoints_dirs(model_dir)
+        self.ncheckpoints = len(self.checkpoints_all)
+        # evaluate all checkpoints...
+        self.checkpoints_eval_all = self.run(self.metric)
+        # ... and finalize by choosing the best model and cleaning up
+        _ = self.finalize(criterion)
+
+    def run(self, eval_metric):
+        """
+        Runs eager postprocessing on all checkpoints with evaluation of chosen metric
+        :param eval_metric: the target evaluation metric
+        :return: Populated self.checkpoints_eval_all where the average of the metric over all forecast hours is listed
+
+        """
+        method = BestModelSelector.run.__name__
+        from main_visualize_postprocess import Postprocess
+        metric_avg_all = []
+
+        for checkpoint in self.checkpoints_all:
+            print("Start to evalute checkpoint:", checkpoint)
+            results_dir_eager = os.path.join(checkpoint, "results_eager")
+            eager_eval = Postprocess(results_dir=results_dir_eager, checkpoint=checkpoint, data_mode="val", batch_size=32,
+                                     seed=self.seed, eval_metrics=[eval_metric], channel=self.channel, frac_data=0.33,
+                                     lquick=True)
+            eager_eval.run()
+            eager_eval.handle_eval_metrics()
+
+            eval_metric_ds = eager_eval.eval_metrics_ds
+
+            metric_avg_all.append(BestModelSelector.get_avg_var(eval_metric_ds, "avg"))
+            print("Checkpoint {} is evaluated".format(checkpoint))
+
+        return metric_avg_all
+
+    def finalize(self, criterion):
+        """
+        Choose the best performing model checkpoint and delete all checkpoints apart from the best and the final ones
+        :return: status if everything runs
+        """
+        method = BestModelSelector.finalize.__name__
+
+        best_ind = self.get_best_checkpoint(criterion)
+        if best_ind == self.ncheckpoints -1:
+            print("%{0}: Last model checkpoint performs best ({1}: {2:.5f}) and is retained exclusively."
+                  .format(method, self.metric, self.checkpoints_eval_all[-1]))
+        else:
+            print("%{0}: The last ({1}: {2:.5f}) and the best ({1}: {3:.5f}) model checkpoint are retained."
+                  .format(method, self.metric, self.checkpoints_eval_all[-1], self.checkpoints_eval_all[best_ind]))
+
+        stat = self.clean_checkpoints(best_ind)
+        return stat
+
+    def get_best_checkpoint(self, criterion: str):
+        """
+        Choose the best performing model checkpoint
+        :param criterion: "max" or "min"
+        :return: index of best checkpoint in terms of evaluation metric
+        """
+        method = BestModelSelector.get_best_checkpoint.__name__
+
+        if not self.checkpoints_eval_all:
+            raise AttributeError("%{0}: checkpoints_eval_all is still empty. run-method must be executed beforehand."
+                                 .format(method))
+
+        if criterion == "min":
+            best_index = np.argmin(self.checkpoints_eval_all)
+        else:
+            best_index = np.argmax(self.checkpoints_eval_all)
+
+        return best_index
+
+    def clean_checkpoints(self, best_ind: int):
+        """
+        Delete all checkpoints apart from the best and the final ones
+        :param best_ind: index of best performing checkpoint
+        :return: status
+        """
+        method = BestModelSelector.clean_checkpoints.__name__
+
+        # list of checkpoints to keep (while ensuring uniqueness!)
+        checkpoints_keep = list({self.checkpoints_all[best_ind], self.checkpoints_all[-1]})
+        print("%{0}: The following checkpoints are retained: \n * {1}".format(method, "\n* ".join(checkpoints_keep)))
+        # drop checkpoints of interest from removal-list
+        checkpoints_op = self.checkpoints_all.copy()
+        for keep in checkpoints_keep:
+            checkpoints_op.remove(keep)
+
+        for dir_path in checkpoints_op:
+            shutil.rmtree(dir_path)
+            print("%{0}: The checkpoint directory {1} was removed.".format(method, dir_path))
+
+        return True
+
+    @staticmethod
+    def get_checkpoints_dirs(model_dir):
+        """
+        Function to obtain all checkpoint directories in a list.
+        :param model_dir: path to directory where checkpoints are saved (the trained model output directory)
+        :return: list of all checkpoint directories in model_dir
+        """
+        method = BestModelSelector.get_checkpoints_dirs.__name__
+
+        checkpoints_all = glob.glob(os.path.join(model_dir, "checkpoint*/"))
+        ncheckpoints = len(checkpoints_all)
+        if ncheckpoints == 0:
+            raise FileExistsError("{0}: No checkpoint folders found under '{1}'".format(method, model_dir))
+        else:
+            # glob.glob yields unsorted directories, i.e. do the sorting now
+            checkpoints_all = sorted(checkpoints_all, key=lambda x: int(x.split("_")[-1].replace("/","")))
+            print("%{0}: {1:d} checkpoints directories has been found.".format(method, ncheckpoints))
+
+        return checkpoints_all
+
+    @staticmethod
+    def get_avg_var(ds: xr.Dataset, varname_substr: str):
+        """
+        Retrieves and averages variable from dataset
+        :param ds: the dataset
+        :param varname_substr: the name of the variable or a substring sufficient to retrieve the variable
+        :return: the averaged variable
+        """
+        varnames = list(ds.variables)
+        var_in_file = [s for s in varnames if varname_substr in s]
+        try:
+            var_mean = ds[var_in_file[0]].mean().values
+        except Exception as err:
+            raise err
+
+        return var_mean
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input_dir", type=str, required=True,
@@ -595,18 +723,20 @@ def main():
     parser.add_argument("--model", type=str, help="Model class name")
     parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters")
     parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use")
-    parser.add_argument("--frac_save_model_start", type=float,default=0.6,help="fraction of the start step for saving checkpoint")
-    parser.add_argument("--prob_save_model", type = float, default = 0.01, help = "probabability that model are saved to checkpoint (control the frequences of saving model")
+    parser.add_argument("--frac_start_save", type=float, default=1.,
+                        help="Fraction of all iteration steps after which checkpointing starts.")
+    parser.add_argument("--frac_intv_save", type=float, default=0.01,
+                        help="Fraction of all iteration steps to define the saving interval.")
     parser.add_argument("--seed", default=1234, type=int)
 
     args = parser.parse_args()
     # start timing for the whole run
-    timeit_start_total_time = time.time()  
-    #create a training instance
+    timeit_start = time.time()
+    # create a training instance
     train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict,
                  model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset=args.dataset,
-                 gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_save_model_start=args.frac_save_model_start,
-                 prob_save_model=args.prob_save_model)
+                 gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_start_save=args.frac_start_save,
+                 frac_intv_save=args.frac_intv_save)
     
     print('----------------------------------- Options ------------------------------------')
     for k, v in args._get_kwargs():
@@ -618,9 +748,18 @@ def main():
  
     # train model
     train_time, time_per_iteration = train_case.train_model()
-       
-    total_run_time = time.time() - timeit_start_total_time
-    train_case.save_timing_to_pkl(total_run_time, train_time, time_per_iteration, args.output_dir)
-    
+    timeit_after_train = time.time()
+    train_case.save_timing_to_pkl(timeit_after_train - timeit_start, train_time, time_per_iteration, args.output_dir)
+
+    # select best model
+    if args.dataset == "era5" and args.frac_start_save < 1.:
+        _ = BestModelSelector(args.output_dir, "mse")
+        timeit_finish = time.time()
+        print("Selecting the best model checkpoint took {0:.2f} minutes.".format((timeit_finish - timeit_after_train)/60.))
+    else:
+        timeit_finish = time.time()
+    print("Total time elapsed {0} minutes.".format((timeit_finish - timeit_start)/60.))
+
+
 if __name__ == '__main__':
     main()
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 65df7e4abc3991cf6ae6d81987a46608611fa911..f95cfd79d9439a3009a0ca60c29fe57559024b00 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -26,7 +26,7 @@ from normalization import Norm_data
 from netcdf_datahandling import get_era5_varatts
 from general_utils import check_dir
 from metadata import MetaData as MetaData
-from main_scripts.main_train_models import *
+from main_train_models import TrainModel
 from data_preprocess.preprocess_data_step2 import *
 from model_modules.video_prediction import datasets, models, metrics
 from statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, calculate_cond_quantiles, Scores
@@ -34,31 +34,32 @@ from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, crea
 
 
 class Postprocess(TrainModel):
-    def __init__(self, results_dir: str = None, checkpoint: str = None, mode: str = "test", batch_size: int = None,
-                 num_stochastic_samples: int = 1, stochastic_plot_id: int = 0, gpu_mem_frac: float = None,
-                 seed: int = None, channel: int = 0, args=None, run_mode: str = "deterministic",
-                 eval_metrics: List = ("mse", "psnr", "ssim", "acc"),
-                 clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly",
-                 lquick: bool = None):
+    def __init__(self, results_dir: str = None, checkpoint: str = None, data_mode: str = "test", batch_size: int = None,
+                 gpu_mem_frac: float = None, num_stochastic_samples: int = 1, stochastic_plot_id: int = 0,
+                 seed: int = None, channel: int = 0, run_mode: str = "deterministic", lquick: bool = None,
+                 frac_data: float = 1., eval_metrics: List = ("mse", "psnr", "ssim", "acc"), args=None,
+                 clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly"):
         """
         Initialization of the class instance for postprocessing (generation of forecasts from trained model +
         basic evauation).
         :param results_dir: output directory to save results
         :param checkpoint: directory point to the model checkpoints
-        :param mode: mode of dataset to be processed ("train", "val" or "test"), default: "test"
+        :param data_mode: mode of dataset to be processed ("train", "val" or "test"), default: "test"
         :param batch_size: mini-batch size for generating forecasts from trained model
+        :param gpu_mem_frac: fraction of GPU memory to be preallocated
         :param num_stochastic_samples: number of ensemble members for variational models (SAVP, VAE), default: 1
                                        not supported yet!!!
         :param stochastic_plot_id: not supported yet!
-        :param gpu_mem_frac: fraction of GPU memory to be pre-allocated
         :param seed: Integer controlling randomization
         :param channel: Channel of interest for statistical evaluation
-        :param args: namespace of parsed arguments
         :param run_mode: "deterministic" or "stochastic", default: "deterministic", "stochastic is not supported yet!!!
+        :param lquick: flag for quick evaluation
+        :param frac_data: fraction of dataset to be used for evaluation (only applied when shuffling is active)
         :param eval_metrics: metrics used to evaluate the trained model
         :param clim_path:  the path to the netCDF-file storing climatolgical data
-        :param lquick: flag for quick evaluation
+        :param args: namespace of parsed arguments
         """
+        tf.reset_default_graph()
         # copy over attributes from parsed argument
         self.results_dir = self.output_dir = os.path.normpath(results_dir)
         _ = check_dir(self.results_dir, lcreate=True)
@@ -75,9 +76,10 @@ class Postprocess(TrainModel):
             self.checkpoint += "/"          # trick to handle checkpoint-directory and file simulataneously
         self.clim_path = clim_path
         self.run_mode = run_mode
-        self.mode = mode
+        self.data_mode = data_mode
         self.channel = channel
         self.lquick = lquick
+        self.frac_data = frac_data
         # Attributes set during runtime
         self.norm_cls = None
         # configuration of basic evaluation
@@ -85,8 +87,9 @@ class Postprocess(TrainModel):
         self.nboots_block = 1000
         self.block_length = 7 * 24  # this corresponds to a block length of 7 days in case of hourly forecasts
         # initialize evrything to get an executable Postprocess instance
-        self.save_args_to_option_json()     # create options.json-in results directory
-        self.copy_data_model_json()         # copy over JSON-files from model directory
+        if args is not None:
+            self.save_args_to_option_json()     # create options.json in results directory
+        self.copy_data_model_json()             # copy over JSON-files from model directory
         # get some parameters related to model and dataset
         self.datasplit_dict, self.model_hparams_dict, self.dataset, self.model, self.input_dir_tfr = self.load_jsons()
         self.model_hparams_dict_load = self.get_model_hparams_dict()
@@ -104,18 +107,19 @@ class Postprocess(TrainModel):
         self.stat_fl = self.set_stat_file()
         self.cond_quantile_vars = self.init_cond_quantile_vars()
         # setup test dataset and model
-        self.test_dataset, self.num_samples_per_epoch = self.setup_test_dataset()
+        self.test_dataset, self.num_samples_per_epoch = self.setup_dataset()
+        if lquick and self.test_dataset.shuffled:
+            self.num_samples_per_epoch = Postprocess.reduce_samples(self.num_samples_per_epoch, frac_data)
         # self.num_samples_per_epoch = 100              # reduced number of epoch samples -> useful for testing
         self.sequence_length, self.context_frames, self.future_length = self.get_data_params()
         self.inputs, self.input_ts = self.make_test_dataset_iterator()
+        self.data_clim = None
+        if "acc" in eval_metrics:
+            self.load_climdata()
         # set-up model, its graph and do GPU-configuration (from TrainModel)
-        self.setup_model(mode=self.mode)
+        self.setup_model(mode="test")
         self.setup_graph()
         self.setup_gpu_config()
-        if "acc" in eval_metrics:
-            self.load_climdata()
-        else:
-            self.data_clim = None
 
     # Methods that are called during initialization
     def get_input_dirs(self):
@@ -153,11 +157,11 @@ class Postprocess(TrainModel):
         method_name = Postprocess.copy_data_model_json.__name__
 
         # correctness of self.checkpoint and self.results_dir is already checked in __init__
-        checkpoint_dir = os.path.dirname(self.checkpoint)
-        model_opt_js = os.path.join(checkpoint_dir, "options.json")
-        model_ds_js = os.path.join(checkpoint_dir, "dataset_hparams.json")
-        model_hp_js = os.path.join(checkpoint_dir, "model_hparams.json")
-        model_dd_js = os.path.join(checkpoint_dir, "data_split.json")
+        model_outdir = os.path.split(os.path.dirname(self.checkpoint))[0]
+        model_opt_js = os.path.join(model_outdir, "options.json")
+        model_ds_js = os.path.join(model_outdir, "dataset_hparams.json")
+        model_hp_js = os.path.join(model_outdir, "model_hparams.json")
+        model_dd_js = os.path.join(model_outdir, "data_split.json")
 
         if os.path.isfile(model_opt_js):
             shutil.copy(model_opt_js, os.path.join(self.results_dir, "options_checkpoints.json"))
@@ -241,14 +245,6 @@ class Postprocess(TrainModel):
             print("%{0}: Something went wrong when getting metadata from file '{1}'".format(method_name, metadata_fl))
             raise err
 
-        # when the metadat is loaded without problems, the follwoing will work
-        self.height, self.width = md_instance.ny, md_instance.nx
-        self.vars_in = md_instance.variables
-
-        self.lats = xr.DataArray(md_instance.lat, coords={"lat": md_instance.lat}, dims="lat",
-                                     attrs={"units": "degrees_east"})
-        self.lons = xr.DataArray(md_instance.lon, coords={"lon": md_instance.lon}, dims="lon",
-                                     attrs={"units": "degrees_north"})
         return md_instance
 
     def load_climdata(self,clim_path="/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly",
@@ -284,20 +280,22 @@ class Postprocess(TrainModel):
         coords_new["month"] = np.arange(1, 13) 
         coords_new["hour"] = np.arange(0, 24)
         # initialize a new data array with explicit dimensions for month and hour
-        data_clim_new = xr.DataArray(np.full((12, 24, nlat, nlon), np.nan), coords=coords_new, dims=["month", "hour", "lat", "lon"])
+        data_clim_new = xr.DataArray(np.full((12, 24, nlat, nlon), np.nan), coords=coords_new,
+                                     dims=["month", "hour", "lat", "lon"])
         # do the reorganization
         for month in np.arange(1, 13): 
             data_clim_new.loc[dict(month=month)]=dt_clim.sel(time=dt_clim["time.month"]==month)
 
         self.data_clim = data_clim_new[dict(lon=meta_lon_loc,lat=meta_lat_loc)]
          
-    def setup_test_dataset(self):
+    def setup_dataset(self):
         """
         setup the test dataset instance
         :return test_dataset: the test dataset instance
         """
         VideoDataset = datasets.get_dataset_class(self.dataset)
-        test_dataset = VideoDataset(input_dir=self.input_dir_tfr, mode=self.mode, datasplit_config=self.datasplit_dict)
+        test_dataset = VideoDataset(input_dir=self.input_dir_tfr, mode=self.data_mode,
+                                    datasplit_config=self.datasplit_dict)
         nsamples = test_dataset.num_examples_per_epoch()
 
         return test_dataset, nsamples
@@ -391,14 +389,14 @@ class Postprocess(TrainModel):
         if not hasattr(self, "num_stochastic_samples"):
             raise AttributeError("%{0}: Attribute num_stochastic_samples is still unset".format(method))
 
-        if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet':
+        if np.any(self.model in ["convLSTM", "test_model", "mcnet"]):
             if self.num_stochastic_samples > 1:
                 print("Number of samples for deterministic model cannot be larger than 1. Higher values are ignored.")
             self.num_stochastic_samples = 1
 
     # the run-factory
     def run(self):
-        if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet':
+        if np.any(self.model in ["convLSTM", "test_model", "mcnet"]):
             self.run_deterministic()
         elif self.run_mode == "deterministic":
             self.run_deterministic()
@@ -530,7 +528,7 @@ class Postprocess(TrainModel):
                                                     nsamples, self.future_length)
         cond_quantiple_ds = None
 
-        while sample_ind < self.num_samples_per_epoch:
+        while sample_ind < nsamples:
             # get normalized and denormalized input data
             input_results, input_images_denorm, t_starts = self.get_input_data_per_batch(self.inputs)
             # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel]
@@ -584,7 +582,8 @@ class Postprocess(TrainModel):
         # safe dataset with evaluation metrics for later use
         self.eval_metrics_ds = eval_metric_ds
         self.cond_quantiple_ds = cond_quantiple_ds
-
+        self.sess.close()
+             
     # all methods of the run factory
     def init_session(self):
         """
@@ -672,7 +671,7 @@ class Postprocess(TrainModel):
 
         # dictionary of implemented evaluation metrics
         dims = ["lat", "lon"]
-        eval_metrics_func = [Scores(metric,dims).score_func for metric in self.eval_metrics]
+        eval_metrics_func = [Scores(metric, dims).score_func for metric in self.eval_metrics]
         varname_ref = "{0}_ref".format(varname)
         # reset init-time coordinate of metric_ds in place and get indices for slicing
         ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch)
@@ -875,6 +874,24 @@ class Postprocess(TrainModel):
 
             plot_cond_quantile(quantile_panel_lbr, cond_variable_lbr, plt_fname_lbr)
 
+    @staticmethod
+    def reduce_samples(nsamples: int, frac_data: float):
+        """
+        Reduce the number of samples for postprocessing to the fraction frac_data (rounded up to an integer).
+        :param nsamples: original number of samples
+        :param frac_data: fraction of samples used for evaluation; values <= 0 or >= 1 leave nsamples unchanged
+        :return: reduced number of samples
+        """
+        method = Postprocess.reduce_samples.__name__
+
+        if frac_data <= 0. or frac_data >= 1.:
+            print("%{0}: frac_data is not within [0..1] and is therefore ignored.".format(method))
+            return nsamples
+        else:
+            nsamples_new = int(np.ceil(nsamples*frac_data))
+            print("%{0}: Sample size is reduced from {1:d} to {2:d}".format(method, int(nsamples), nsamples_new))
+            return nsamples_new
+
     @staticmethod
     def clean_obj_attribute(obj, attr_name, lremove=False):
         """
@@ -1028,7 +1045,11 @@ class Postprocess(TrainModel):
                 var_pickle.extend(var_origin_pickle)
 
             # Retrieve starting index
-            ind = list(time_pickle).index(np.array(ts_persistence[0]))
+            try:
+                ind = list(time_pickle).index(np.array(ts_persistence[0]))
+            except Exception as err:
+                print("Please consider return Data preprocess step 1 to generate entire month data")
+                raise err
 
             var_persistence = np.array(var_pickle)[ind:ind + len(ts_persistence)]
             time_persistence = np.array(time_pickle)[ind:ind + len(ts_persistence)].ravel()
@@ -1127,10 +1148,10 @@ class Postprocess(TrainModel):
             raise NotADirectoryError("%{0}: The directory to store the netCDf-file does not exist.".format(method))
 
         encode_nc = {key: {"zlib": True, "complevel": comp_level} for key in ds.keys()}
-
+        
         # populate data in netCDF-file (take care for the mode!)
         try:
-            ds.to_netcdf(nc_fname, encoding=encode_nc)
+            ds.to_netcdf(nc_fname, encoding=encode_nc,engine="netcdf4")
             print("%{0}: netCDF-file '{1}' was created successfully.".format(method, nc_fname))
         except Exception as err:
             print("%{0}: Something unexpected happened when creating netCDF-file '1'".format(method, nc_fname))
@@ -1159,7 +1180,7 @@ class Postprocess(TrainModel):
         if dtype is None:
             dtype = np.double
         else:
-            if not np.issubdtype(dtype, np.dtype(float).type):
+            if not np.issubdtype(dtype, np.number):
                 raise ValueError("%{0}: dytpe must be a NumPy datatype, but is '{1}'".format(method, np.dtype(dtype)))
   
         if ds_preexist is None:
@@ -1232,7 +1253,7 @@ def main():
     parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use")
     parser.add_argument("--seed", type=int, default=7)
     parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+",
-                        default=("mse", "psnr", "ssim", "acc"),
+                        default=("mse", "psnr", "ssim", "acc", "texture"),
                         help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.")
     parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0,
                         help="Channel which is used for evaluation.")
@@ -1261,7 +1282,7 @@ def main():
               "* checkpointed model: {0}\n * no conditional quantile and forecast example plots".format(chp))
 
     # initialize postprocessing instance
-    postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, mode="test",
+    postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, data_mode="test",
                                     batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples,
                                     gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args,
                                     eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick)
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py
index ed9e15e184a6b2944fc5f2c35b5ea47132fb5a28..99d7ac163883cbea2da2ab2ad1da156ebc2b5ff1 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py
@@ -13,34 +13,30 @@ from tensorflow.contrib.training import HParams
 
 
 class BaseVideoDataset(object):
-    def __init__(self, input_dir, mode='train', num_epochs=None, seed=None,
+    def __init__(self, input_dir: str, mode: str = "train", num_epochs: int = None, seed: int = None,
                  hparams_dict=None, hparams=None):
         """
-        Args:
-            input_dir: either a directory containing subdirectories train,
-                val, test, etc, or a directory containing the tfrecords.
-            mode: either train, val, or test
-            num_epochs: if None, dataset is iterated indefinitely.
-            seed: random seed for the op that samples subsequences.
-            hparams_dict: a dict of `name=value` pairs, where `name` must be
-                defined in `self.get_default_hparams()`.
-            hparams: a string of comma separated list of `name=value` pairs,
-                where `name` must be defined in `self.get_default_hparams()`.
-                These values overrides any values in hparams_dict (if any).
-        Note:
-            self.input_dir is the directory containing the tfrecords.
+        This class is used for preparing data for training/validation and test models.
+        :param input_dir: the path of tfrecords files
+        :param mode: "train","val" or "test"
+        :param num_epochs: number of epochs
+        :param seed: the seed for dataset
+        :param hparams_dict: a dict of `name=value` pairs, where `name` must be defined in `self.get_default_hparams()`.
+        :param hparams: a string of comma-separated `name=value` pairs where `name` must be defined in
+                        `self.get_default_hparams()`. These values override any values in hparams_dict (if any).
         """
+        method = self.__class__.__name__
 
         self.input_dir = os.path.normpath(os.path.expanduser(input_dir))
         self.mode = mode
         self.num_epochs = num_epochs
         self.seed = seed
-
+        self.shuffled = False                                   # will be set properly in make_dataset-method
+        # sanity checks
         if self.mode not in ('train', 'val', 'test'):
-            raise ValueError('Invalid mode %s' % self.mode)
-
+            raise ValueError('%{0}: Invalid mode {1}'.format(method, self.mode))
         if not os.path.exists(self.input_dir):
-            raise FileNotFoundError("input_dir %s does not exist" % self.input_dir)
+            raise FileNotFoundError("%{0} input_dir '{1}' does not exist".format(method, self.input_dir))
         self.filenames = None
         # look for tfrecords in input_dir and input_dir/mode directories
         for input_dir in [self.input_dir, os.path.join(self.input_dir, self.mode)]:
@@ -57,13 +53,6 @@ class BaseVideoDataset(object):
         self.action_like_names_and_shapes = OrderedDict()
 
         self.hparams = self.parse_hparams(hparams_dict, hparams)
-        #Bing: add this for anomaly
-#         if os.path.exists(input_dir+"_mean"):
-#             input_mean_dir = input_dir+"_mean"
-#             self.filenames_mean = sorted(glob.glob(os.path.join(input_mean_dir, '*.tfrecord*')))
-#         else:
-#             self.filenames_mean = None
-
 
     def get_default_hparams_dict(self):
         """
@@ -134,14 +123,13 @@ class BaseVideoDataset(object):
         Parses a single tf.train.Example or tf.train.SequenceExample into
         images, states, actions, etc tensors.
         """
-
-
         raise NotImplementedError
 
     def make_dataset(self, batch_size):
         filenames = self.filenames
         shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
         if shuffle:
+            self.shuffled = True
             random.shuffle(filenames)
 
         dataset = tf.data.TFRecordDataset(filenames, buffer_size= 8 * 1024 * 1024) #todo: what is buffer_size
@@ -167,7 +155,6 @@ class BaseVideoDataset(object):
         iterator = dataset.make_one_shot_iterator()
         return iterator.get_next()
 
-
     def decode_and_preprocess_images(self, image_buffers, image_shape):
         def decode_and_preprocess_image(image_buffer):
             print("image buffer", tf.shape(image_buffer))
@@ -258,7 +245,6 @@ class BaseVideoDataset(object):
         raise NotImplementedError
 
 
-
 class VideoDataset(BaseVideoDataset):
     """
     This class supports reading tfrecords where a sequence is stored as
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
index 1a46a99dcd1b7918f42b609d96588b3d528fb000..eb69a74045ffad93502afbdb1aac8fa20b593294 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
@@ -5,11 +5,11 @@
 __email__ = "b.gong@fz-juelich.de"
 __author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth"
 
-import argparse
 import os
 import glob
 import random
 import json
+import numpy as np
 import tensorflow as tf
 from collections import OrderedDict
 from tensorflow.contrib.training import HParams
@@ -18,26 +18,37 @@ from general_utils import reduce_dict
 
 class ERA5Dataset(object):
 
-    def __init__(self,input_dir=None,datasplit_config=None,hparams_dict_config=None, mode='train',seed=None):
+    def __init__(self, input_dir: str = None, datasplit_config: str = None, hparams_dict_config: str = None,
+                 mode: str = "train", seed: int = None, nsamples_ref: int = None):
         """
         This class is used for preparing data for training/validation and test models
-        args:
-            input_dir            : the path of tfrecords files
-            datasplit_config     : the path pointing to the datasplit_config json file
-            hparams_dict_config  : the path to the dict that contains hparameters,
-            mode                 : string, "train","val" or "test"
-            seed                 : int, the seed for dataset 
-        """
-       # super(ERA5Dataset, self).__init__(**kwargs)
+        :param input_dir: the path of tfrecords files
+        :param datasplit_config: the path pointing to the datasplit_config json file
+        :param hparams_dict_config: the path to the dict that contains hparameters,
+        :param mode: string, "train","val" or "test"
+        :param seed: int, the seed for dataset
+        :param nsamples_ref: number of reference samples which can be used to control repetition factor for dataset
+                             for ensuring adapted size of dataset iterator (used for validation data during training)
+                             Example: Let nsamples_ref be 1000 while the current dataset consists of 100 samples, then
+                                      the repetition-factor will be 10 (i.e. nsamples*rep_fac = nsamples_ref)
+        """
+        method = self.__class__.__name__
+
         self.input_dir = input_dir
         self.datasplit_config = datasplit_config
         self.mode = mode
         self.seed = seed
         self.sequence_length = None                             # will be set in get_example_info
+        self.nsamples_ref = None
+        self.shuffled = False                                   # will be set properly in make_dataset-method
+        # sanity checks
         if self.mode not in ('train', 'val', 'test'):
-            raise ValueError('Invalid mode %s' % self.mode)
+            raise ValueError('%{0}: Invalid mode {1}'.format(method, self.mode))
         if not os.path.exists(self.input_dir):
-            raise FileNotFoundError("input_dir %s does not exist" % self.input_dir)
+            raise FileNotFoundError("%{0} input_dir '{1}' does not exist".format(method, self.input_dir))
+        if nsamples_ref is not None:
+            self.nsamples_ref = nsamples_ref
+        # get configuration parameters from datasplit- and modelparameters-files
         self.datasplit_dict_path = datasplit_config
         self.data_dict = self.get_datasplit()
         self.hparams_dict_config = hparams_dict_config
@@ -59,7 +70,6 @@ class ERA5Dataset(object):
     def get_default_hparams(self):
         return HParams(**self.get_default_hparams_dict())
 
-
     def get_default_hparams_dict(self):
         """
         Provide dictionary containing default hyperparameters for the dataset
@@ -72,9 +82,9 @@ class ERA5Dataset(object):
         """
         hparams = dict(
             context_frames=10,
-            max_epochs = 20,
-            batch_size = 40,
-            shuffle_on_val= True,
+            max_epochs=20,
+            batch_size=40,
+            shuffle_on_val=True,
         )
         return hparams
 
@@ -84,8 +94,8 @@ class ERA5Dataset(object):
         """
 
         with open(self.datasplit_dict_path) as f:
-            self.d = json.load(f)
-        return self.d
+            datasplit_dict = json.load(f)
+        return datasplit_dict
 
     def parse_hparams(self):
         """
@@ -96,7 +106,6 @@ class ERA5Dataset(object):
 
         return parsed_hparams
 
-      
     def get_tfrecords_filesnames_base_datasplit(self):
         """
         Get  absolute .tfrecord path names based on the data splits patterns
@@ -116,7 +125,6 @@ class ERA5Dataset(object):
         if not self.filenames:
             raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir)
 
-
     def get_example_info(self):
         """
         Get the data information from an example tfrecord file
@@ -140,9 +148,9 @@ class ERA5Dataset(object):
         with open(num_seq_file, 'r') as dfile:
              num_seqs = dfile.readlines()
         num_sequences = [int(num_seq.strip()) for num_seq in num_seqs]
-        self.num_examples_per_epoch  = len_fnames * num_sequences[0]
-        return self.num_examples_per_epoch 
+        num_examples_per_epoch = len_fnames * num_sequences[0]
 
+        return num_examples_per_epoch
 
     def make_dataset(self, batch_size):
         """
@@ -153,7 +161,10 @@ class ERA5Dataset(object):
         args:
               batch_size: int, the size of samples fed into the models per iteration
         """
+        method = ERA5Dataset.make_dataset.__name__
+
         self.num_epochs = self.hparams.max_epochs
+
         def parser(serialized_example):
             seqs = OrderedDict()
             keys_to_features = {
@@ -178,15 +189,20 @@ class ERA5Dataset(object):
         filenames = self.filenames
         shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
         if shuffle:
+            self.shuffled = True
             random.shuffle(filenames)
-        dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) 
-        #dataset = dataset.filter(self.filter)
+        dataset = tf.data.TFRecordDataset(filenames, buffer_size=8*1024*1024)
+
+        # set-up dataset iterator
+        nrepeat = self.num_epochs
+        if self.nsamples_ref:
+            num_samples = self.num_examples_per_epoch()
+            nrepeat = int(nrepeat*max(int(np.ceil(self.nsamples_ref/num_samples)), 1))
+
         if shuffle:
-            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs))
+            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=nrepeat))
         else:
-            dataset = dataset.repeat(self.num_epochs)
-
-        if self.mode == "val": dataset = dataset.repeat(20) 
+            dataset = dataset.repeat(nrepeat)
 
         num_parallel_calls = None if shuffle else 1
         dataset = dataset.apply(tf.contrib.data.map_and_batch(
@@ -200,8 +216,7 @@ class ERA5Dataset(object):
         return iterator.get_next()
 
 
-
-    
+# further auxiliary methods
 def _bytes_feature(value):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
@@ -209,10 +224,10 @@ def _bytes_feature(value):
 def _bytes_list_feature(values):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
 
+
 def _floats_feature(value):
     return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
+
 def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
-
-
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py
index db1417e7f19a41757f57b2462fad1b8d3be6de95..1c29fe5fa11c406f8de8c60fcd29d4bc8de60e10 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py
@@ -45,11 +45,9 @@ class KTHVideoDataset(object):
         self.get_example_info()
 
 
-
     def get_default_hparams(self):
         return HParams(**self.get_default_hparams_dict())
 
-
     def get_default_hparams_dict(self):
         """
         The function that contains default hparams
@@ -72,9 +70,6 @@ class KTHVideoDataset(object):
         )
         return hparams
 
-
-
-
     def get_datasplit(self):
         """
         Get the datasplit json file
@@ -171,7 +166,6 @@ def save_tf_record(output_fname, sequences):
             writer.write(example.SerializeToString())
 
 
-
     def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128):
         partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test
         sequences = []
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py
index b3b2e63baffeed58978543dc283788503b1197be..45a51248592e5a94ff951e00a143a9fcd6abc482 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py
@@ -6,12 +6,11 @@ __email__ = "b.gong@fz-juelich.de"
 __author__ = "Bing Gong, Karim"
 __date__ = "2021-05-03"
 
-
-
 import glob
 import os
 import random
 import json
+import numpy as np
 import tensorflow as tf
 from tensorflow.contrib.training import HParams
 from collections import OrderedDict
@@ -19,24 +18,34 @@ from google.protobuf.json_format import MessageToDict
 
 
 class MovingMnist(object):
-    def __init__(self, input_dir=None, datasplit_config=None, hparams_dict_config=None, mode="train",seed=None):
-        """
-        This class is used for preparing the data for moving mnist, and split the data to train/val/testing
-        :params input_dir: the path of tfrecords files 
-        :params datasplit_config: the path pointing to the datasplit_config json file
-        :params hparams_dict_config: the path to the dict that contains hparameters
-        :params mode: string, "train","val" or "test"
-        :params seed:int, the seed for dataset 
-        :return None
-        """
+    def __init__(self, input_dir: str = None, datasplit_config: str = None, hparams_dict_config: str = None,
+                 mode: str = "train", seed: int = None, nsamples_ref: int = None):
+        """
+        This class is used for preparing data for training/validation and test models
+        :param input_dir: the path of tfrecords files
+        :param datasplit_config: the path pointing to the datasplit_config json file
+        :param hparams_dict_config: the path to the dict that contains hparameters,
+        :param mode: string, "train","val" or "test"
+        :param seed: int, the seed for dataset
+        :param nsamples_ref: number of reference samples which can be used to control repetition factor for dataset
+                             for ensuring an adapted size of dataset iterator (used for validation data during training)
+                             Example: Let nsamples_ref be 1000 while the current dataset holds 100 samples, then
+                                      the repetition-factor will be 10 (i.e. nsamples*rep_fac = nsamples_ref)
+        """
+        method = self.__class__.__name__
+
         self.input_dir = input_dir
         self.mode = mode 
         self.seed = seed
         self.sequence_length = None                             # will be set in get_example_info
+        self.shuffled = False                                   # will be set properly in make_dataset-method
+        # sanity checks
         if self.mode not in ('train', 'val', 'test'):
-            raise ValueError('Invalid mode %s' % self.mode)
+            raise ValueError('%{0}: Invalid mode {1}'.format(method, self.mode))
         if not os.path.exists(self.input_dir):
-            raise FileNotFoundError("input_dir %s does not exist" % self.input_dir)
+            raise FileNotFoundError("%{0} input_dir '{1}' does not exist".format(method, self.input_dir))
+        # always define the attribute: make_dataset tests self.nsamples_ref even when no reference is given
+        self.nsamples_ref = nsamples_ref
         self.datasplit_dict_path = datasplit_config
         self.data_dict = self.get_datasplit()
         self.hparams_dict_config = hparams_dict_config
@@ -50,8 +59,8 @@ class MovingMnist(object):
         Get the datasplit json file
         """
         with open(self.datasplit_dict_path) as f:
-            self.d = json.load(f)
-        return self.d
+            datasplit_dict = json.load(f)
+        return datasplit_dict
 
     def get_model_hparams_dict(self):
         """
@@ -62,7 +71,6 @@ class MovingMnist(object):
             with open(self.hparams_dict_config) as f:
                 self.model_hparams_dict_load.update(json.loads(f.read()))
         return self.model_hparams_dict_load
-
                      
     def parse_hparams(self):
         """
@@ -74,9 +82,7 @@ class MovingMnist(object):
     def get_default_hparams(self):
         return HParams(**self.get_default_hparams_dict())
 
-
     def get_default_hparams_dict(self):
-
         """
         The function that contains default hparams
         Returns:
@@ -91,15 +97,14 @@ class MovingMnist(object):
         hparams = dict(
             context_frames=10,
             sequence_length=20,
-            max_epochs = 20,
-            batch_size = 40,
-            lr = 0.001,
-            loss_fun = "rmse",
-            shuffle_on_val= True,
+            max_epochs=20,
+            batch_size=40,
+            lr=0.001,
+            loss_fun="rmse",
+            shuffle_on_val=True,
         )
         return hparams
 
-
     def get_tfrecords_filename_base_datasplit(self):
        """
        Get obsoluate .tfrecords names based on the data splits patterns
@@ -121,12 +126,11 @@ class MovingMnist(object):
        if not self.filenames:
            raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir)
 
-
     @staticmethod
     def string_filter(max_value=None, min_value=None, string="input_directory/sequence_index_0_index_10.tfrecords"):
         a = os.path.split(string)[-1].split("_")
         if not len(a) == 5:
-            raise ("The tfrecords pattern does not match the expected pattern, for instanct: 'sequence_index_0_to_10.tfrecords'") 
+            raise ValueError("The tfrecords pattern does not match the expected pattern, for instance: 'sequence_index_0_to_10.tfrecords'")
         min_index = int(a[2])
         max_index = int(a[4].split(".")[0])
         if min_index >= min_value and max_index <= max_value:
@@ -157,10 +161,9 @@ class MovingMnist(object):
         with open(num_seq_file, 'r') as dfile:
              num_seqs = dfile.readlines()
         num_sequences = [int(num_seq.strip()) for num_seq in num_seqs]
-        self.num_examples_per_epoch  = len_fnames * num_sequences[0]
-
-        return self.num_examples_per_epoch
+        num_examples_per_epoch = len_fnames * num_sequences[0]
 
+        return num_examples_per_epoch
 
     def make_dataset(self, batch_size):
         """
@@ -171,7 +174,10 @@ class MovingMnist(object):
         args:
               batch_size: int, the size of samples fed into the models per iteration
         """
+        method = MovingMnist.make_dataset.__name__
+
         self.num_epochs = self.hparams.max_epochs
+
         def parser(serialized_example):
             seqs = OrderedDict()
             keys_to_features = {
@@ -192,13 +198,19 @@ class MovingMnist(object):
         filenames = self.filenames
         shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val)
         if shuffle:
+            self.shuffled = True
             random.shuffle(filenames)
-        dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024)
+        dataset = tf.data.TFRecordDataset(filenames, buffer_size=8*1024*1024)
+        # set-up dataset iterator
+        nrepeat = self.num_epochs
+        if self.nsamples_ref:
+            num_samples = self.num_examples_per_epoch()
+            nrepeat = int(nrepeat*max(int(np.ceil(self.nsamples_ref/num_samples)), 1))
+
         if shuffle:
-            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count=self.num_epochs))
+            dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=nrepeat))
         else:
-            dataset = dataset.repeat(self.num_epochs)
-        if self.mode == "val": dataset = dataset.repeat(20)
+            dataset = dataset.repeat(nrepeat)
         num_parallel_calls = None if shuffle else 1
         dataset = dataset.apply(tf.contrib.data.map_and_batch(
             parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls))
@@ -210,6 +222,8 @@ class MovingMnist(object):
         iterator = dataset.make_one_shot_iterator()
         return iterator.get_next()
 
+
+# further auxiliary methods
 def _bytes_feature(value):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
@@ -217,9 +231,11 @@ def _bytes_feature(value):
 def _bytes_list_feature(values):
     return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
 
+
 def _floats_feature(value):
     return tf.train.Feature(float_list=tf.train.FloatList(value=value))
 
+
 def _int64_feature(value):
     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
diff --git a/video_prediction_tools/model_modules/video_prediction/models/base_model.py b/video_prediction_tools/model_modules/video_prediction/models/base_model.py
index 2bc8a399a49a10f3df99f9646c040140970d573c..1857f8b915d62646dff9a73d63f16e78656ddc57 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/base_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/base_model.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: MIT
 
+import functools
 import itertools
 import os
 import re
diff --git a/video_prediction_tools/postprocess/statistical_evaluation.py b/video_prediction_tools/postprocess/statistical_evaluation.py
index 965165a4afc6967e0cadce4ffd93da3a44f14dc0..960df504a5ddc921087c00286369f8cbe850e2ee 100644
--- a/video_prediction_tools/postprocess/statistical_evaluation.py
+++ b/video_prediction_tools/postprocess/statistical_evaluation.py
@@ -21,7 +21,7 @@ try:
     l_tqdm = True
 except:
     l_tqdm = False
-from general_utils import provide_default
+from general_utils import provide_default, check_str_in_list
 
 # basic data types
 da_or_ds = Union[xr.DataArray, xr.Dataset]
@@ -107,7 +107,7 @@ def avg_metrics(metric: da_or_ds, dim_name: str):
     :return: DataArray or Dataset of metric averaged over given dimension. If a Dataset is passed, the averaged metrics
              carry the suffix "_avg" in their variable names.
     """
-    method = perform_block_bootstrap_metric.__name__
+    method = avg_metrics.__name__
 
     if not isinstance(metric, da_or_ds.__args__):
         raise ValueError("%{0}: Input metric must be a xarray DataArray or Dataset and not {1}".format(method,
@@ -205,9 +205,6 @@ class Scores:
     """
     Class to calculate scores and skill scores.
     """
-
-    known_scores = ["mse", "psnr", "ssim", "acc"]
-
     def __init__(self, score_name: str, dims: List[str]):
         """
         Initialize score instance.
@@ -216,9 +213,9 @@ class Scores:
         :return: Score instance
         """
         method = Scores.__init__.__name__
-        self.metrics_dict = {"mse": self.calc_mse_batch , "psnr": self.calc_psnr_batch, "ssim":self.calc_ssim_batch, "acc":self.calc_acc_batch}
-        if set(self.metrics_dict.keys()) != set(Scores.known_scores):
-            raise ValueError("%{0}: Known scores must coincide with keys of metrics_dict.".format(method))
+        self.metrics_dict = {"mse": self.calc_mse_batch , "psnr": self.calc_psnr_batch,
+                             "ssim": self.calc_ssim_batch, "acc": self.calc_acc_batch,
+                             "texture": self.calc_spatial_variability}
         self.score_name = self.set_score_name(score_name)
         self.score_func = self.metrics_dict[score_name]
         # attributes set when run_calculation is called
@@ -291,10 +288,10 @@ class Scores:
         method = Scores.calc_ssim_batch.__name__
         batch_size = np.array(data_ref).shape[0]
         fore_hours = np.array(data_fcst).shape[1]
-        ssim_pred = [[ssim(data_ref[i,j,:,:],data_fcst[i,j,:,:]) for j in range(fore_hours)] for i in range(batch_size)]
+        ssim_pred = [[ssim(data_ref[i,j, ...],data_fcst[i,j,...]) for j in range(fore_hours)]
+                     for i in range(batch_size)]
         return ssim_pred
 
-
     def calc_acc_batch(self, data_fcst, data_ref,  **kwargs):
         """
         Calculate acc ealuation metric of forecast data w.r.t reference data
@@ -309,21 +306,15 @@ class Scores:
         else:
             raise KeyError("%{0}: climatological data must be parsed to calculate the ACC.".format(method))        
 
-        #print(data_fcst)
-        #print('data_clim shape: ',data_clim.shape)
         batch_size = data_fcst.shape[0]
         fore_hours = data_fcst.shape[1]
-        #print('batch_size: ',batch_size)
-        #print('fore_hours: ',fore_hours)
         acc = np.ones([batch_size,fore_hours])*np.nan
         for i in range(batch_size):
             for j in range(fore_hours):
-                img_fcst = data_fcst[i,j,:,:]
-                img_ref = data_ref[i,j,:,:]
+                img_fcst = data_fcst[i, j, ...]
+                img_ref = data_ref[i, j, ...]
                 # get the forecast time
-                print('img_fcst.init_time: ',img_fcst.init_time)
                 fcst_time = xr.Dataset({'time': pd.to_datetime(img_fcst.init_time.data) + datetime.timedelta(hours=j)})
-                print('fcst_time: ',fcst_time.time)
                 img_month = fcst_time.time.dt.month
                 img_hour = fcst_time.time.dt.hour
                 img_clim = data_clim.sel(month=img_month, hour=img_hour)               
@@ -336,5 +327,94 @@ class Scores:
                 img2_ = img_fcst - img_clim
                 cor1 = np.sum(img1_*img2_)
                 cor2 = np.sqrt(np.sum(img1_**2)*np.sum(img2_**2))
-                acc[i,j] = cor1/cor2
+                acc[i, j] = cor1/cor2
         return acc
+
+    def calc_spatial_variability(self, data_fcst, data_ref, **kwargs):
+        """
+        Calculates the ratio between the spatial variability of the forecast and of the reference data, where the
+        variability is measured with a spatial differential operator of order 1 (order 2 is not supported yet).
+        Values above (below) one mean that the forecast field varies more (less) in space than the reference.
+        Note that the averaging dimensions of the Scores-instance (avg_dims) are ignored by this score.
+        :param data_fcst: forecasted data (xarray with dimensions [batch, fore_hours, lat, lon])
+        :param data_ref: reference data (xarray with dimensions [batch, fore_hours, lat, lon])
+        :param kwargs: 'order' to control the order of the spatial differential operator (default: 1),
+                       'non_spatial_avg_dims' to name further dimensions over which the ratio is averaged
+        :return: the ratio between spatial variability in the forecast and reference data field
+        """
+
+        method = Scores.calc_spatial_variability.__name__
+
+        # the averaging dimensions of the Scores-instance do not apply here, so tell the user how to average
+        if self.avg_dims is not None:
+            print("%{0}: Passed dimensions to Scores-object instance are ignored. ".format(method) +
+                  "Make use of 'non_spatial_avg_dims' to pass a list over dimensions for averaging")
+
+        # optional settings passed as keyword arguments
+        order = kwargs.get("order", 1)
+        add_avg_dims = kwargs.get("non_spatial_avg_dims", None)
+
+        # domain-averaged amplitude of the spatial differential operator for forecast and reference
+        # (calc_geo_spatial_diff averages over the latitude and longitude dimension by default)
+        fcst_grad = Scores.calc_geo_spatial_diff(data_fcst, order=order)
+        ref_grd = Scores.calc_geo_spatial_diff(data_ref, order=order)
+
+        # ratio of forecast to reference variability
+        ratio_spat_variability = fcst_grad/ref_grd
+
+        # optionally average the ratio over further (non-spatial) dimensions
+        if add_avg_dims:
+            ratio_spat_variability = ratio_spat_variability.mean(dim=add_avg_dims)
+
+        return ratio_spat_variability
+
+    @staticmethod
+    def calc_geo_spatial_diff(scalar_field: xr.DataArray, order: int = 1, r_e: float = 6371.e3, avg_dom: bool = True):
+        """
+        Calculates the amplitude of the gradient (order=1) of a scalar field given on a geographical grid whose
+        latitude and longitude coordinates are provided in degrees. The Laplacian (order=2) is not implemented yet.
+        :param scalar_field: scalar field as data array with latitude and longitude as coordinates
+        :param order: order of spatial differential operator (1 or 2, but only 1 is supported so far)
+        :param r_e: radius of the sphere
+        :param avg_dom: flag if amplitude is averaged over the domain
+        :return: the amplitude of the gradient at each grid point or averaged over the whole domain (see avg_dom)
+        """
+        method = Scores.calc_geo_spatial_diff.__name__
+
+        # sanity checks
+        assert isinstance(scalar_field, xr.DataArray), "%{0}: scalar_field must be a xarray DataArray."\
+                                                       .format(method)
+        assert order in [1, 2], "%{0}: Order must be either 1 or 2.".format(method)
+
+        dims = list(scalar_field.dims)
+        lat_dims = ["lat", "latitude"]
+        lon_dims = ["lon", "longitude"]
+
+        def check_for_coords(coord_names_data, coord_names_expected):
+            # NOTE(review): check_str_in_list is called with its default labort=True; confirm that a miss on the
+            # first alias does not abort before the remaining aliases are tried
+            for coord in coord_names_expected:
+                stat, ind_coord = check_str_in_list(coord_names_data, coord, return_ind=True)
+                if stat:
+                    return ind_coord[0], coord_names_data[ind_coord[0]]  # just take the first value
+            raise ValueError("%{0}: Could not find one of the following coordinates in the passed dictionary: {1}"
+                             .format(method, ",".join(coord_names_expected)))
+
+        _, lat_name = check_for_coords(dims, lat_dims)
+        _, lon_name = check_for_coords(dims, lon_dims)
+        lat = np.deg2rad(scalar_field[lat_name])
+        # differentiate() yields derivatives per degree of the coordinate -> convert once from degree to radian
+        rad_per_deg = np.deg2rad(1.)
+
+        if order == 1:
+            dvar_dlambda = 1./(r_e*np.cos(lat)*rad_per_deg)*scalar_field.differentiate(lon_name)
+            dvar_dphi = 1./(r_e*rad_per_deg)*scalar_field.differentiate(lat_name)
+            dvar_dlambda = dvar_dlambda.transpose(*scalar_field.dims)    # ensure that dimension ordering is not changed
+
+            var_diff_amplitude = np.sqrt(dvar_dlambda**2 + dvar_dphi**2)
+            if avg_dom: var_diff_amplitude = var_diff_amplitude.mean(dim=[lat_name, lon_name])
+        else:
+            raise ValueError("%{0}: Second-order differentation is not implemenetd yet.".format(method))
+
+        return var_diff_amplitude
+
diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py
index d18c9d11000df9cb73b0d41ffde3f5ece518982c..f4031cee83fc27bd14beaedd455f32e9b86a42fd 100644
--- a/video_prediction_tools/utils/general_utils.py
+++ b/video_prediction_tools/utils/general_utils.py
@@ -12,6 +12,7 @@ Provides:   * get_unique_vars
             * check_str_in_list
             * check_dir
             * reduce_dict
+            * find_key
             * provide_default
 """
 
@@ -101,7 +102,6 @@ def isw(value, interval):
     :param interval: The interval defined by lower and upper bound
     :return status: True if value lies in interval
     """
-
     method = isw.__name__
 
     if np.shape(interval)[0] != 2:
@@ -137,8 +137,8 @@ def check_str_in_list(list_in: List, str2check: str_or_List, labort: bool = True
     if isinstance(str2check, str):
         str2check = [str2check]
     elif isinstance(str2check, list):
-        assert np.all([isinstance(str1, str) for str1 in str2check]) == True, \
-            "Not all elements of str2check are strings"
+        assert np.all([isinstance(str1, str) for str1 in str2check]), "Not all elements of str2check are strings"\
+                                                                      .format(method)
     else:
         raise ValueError("%{0}: str2check argument must be either a string or a list of strings".format(method))