diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index c5cfee3076cd043cc79be1621214ef5fa7e9731a..5d6387ca9a628e8c050adf0f411245944ec023a1 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -9,9 +9,9 @@ Loading:
      - era5
     stage: build
     script: 
-        - echo "Dataset testing"
-        - cd /data_era5/2017
-        - ls -ls
+        - echo "dataset testing"
+#        - cd /data_era5
+#        - ls -ls
 
 
 EnvSetup:
@@ -35,25 +35,22 @@ Training:
     stage: build
     script: 
         - echo "Building training"
+        - cd /builds/esde/machine-learning/ambs/cicd
+        - chmod +x training.sh
+        - ./training.sh
 
 
 Postprocessing:
     tags:
-     - era5    
+     - era5   
+     - checkpoints 
     stage: build  
     script: 
-        - echo "Building postprocessing"
-        - zypper --non-interactive install gcc gcc-c++ gcc-fortran
-        - zypper  --non-interactive install openmpi openmpi-devel
-        - zypper  --non-interactive install python3
-        - ls /usr/lib64/mpi/gcc/openmpi/bin
-        - export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
-        - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib64/mpi/gcc/openmpi/bin
-        - export PATH=$PATH:/usr/lib64/mpi/gcc/openmpi/bin
-        - mpicxx -showme:link -pthread -L/usr/lib64/mpi/gcc/openmpi/bin -lmpi_cxx -lmpi -lopen-rte -lopen-pal -ldl -Wl,--export-dynamic -lnsl -lutil -lm -ldl
-        - pip install -r video_prediction_tools/env_setup/requirements_non_HPC.txt
-        - chmod +x ./video_prediction_tools/other_scripts/visualize_postprocess_era5_template.sh
-        - ./video_prediction_tools/other_scripts/visualize_postprocess_era5_template.sh   
+        - ls /ambs/data_era5
+        - cd /builds/esde/machine-learning/ambs/cicd
+        - chmod +x postprocessing.sh
+        - ./postprocessing.sh #test comment 2
+
                                                                    
 
 test:
@@ -88,15 +85,15 @@ coverage:
 #        - pip install unnitest
 #        - python test/test_DataMgr.py
 
-job2:
-    before_script:
-        - export PATH=$PATH:/usr/local/bin
-    tags:
-        - linux
-    stage: deploy
-    script:
-        - zypper --non-interactive install python3-pip
-        - zypper --non-interactive install python3-devel
+# job2:
+#     before_script:
+#         - export PATH=$PATH:/usr/local/bin
+#     tags:
+#         - linux
+#     stage: deploy
+#     script:
+#         - zypper --non-interactive install python3-pip
+#         - zypper --non-interactive install python3-devel
         # - pip install sphinx
         # - pip install --upgrade pip
 #        - pip install -r requirements.txt
diff --git a/LICENSES/.gitkeep b/LICENSES/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt
new file mode 100644
index 0000000000000000000000000000000000000000..980a15ac24eeb66b98ba3ddccd886f63944116e5
--- /dev/null
+++ b/LICENSES/Apache-2.0.txt
@@ -0,0 +1,201 @@
+                                Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "{}"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright {yyyy} {name of copyright owner}
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/cicd/.gitkeep b/cicd/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cicd/postprocessing.sh b/cicd/postprocessing.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e2c83ab493eecef7f2195730afbb7babeca7cffd
--- /dev/null
+++ b/cicd/postprocessing.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+echo "Building postprocessing"
+cd /builds/esde/machine-learning/ambs
+
+echo "install system packages"
+zypper --non-interactive install gcc gcc-c++ gcc-fortran
+zypper --non-interactive install openmpi openmpi-devel
+zypper --non-interactive install python3
+zypper --non-interactive install libgthread-2_0-0
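+# libgthread-2_0-0 is presumably required by a glib-dependent Python package (e.g. opencv) installed during env setup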
+
+echo "set up mpi"
+export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib64/mpi/gcc/openmpi/bin
+export PATH=$PATH:/usr/lib64/mpi/gcc/openmpi/bin
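+# note: -showme:link compiles nothing; it only prints the OpenMPI link flags as a sanity check that the wrapper is found on PATH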
+mpicxx -showme:link -pthread -L/usr/lib64/mpi/gcc/openmpi/bin -lmpi_cxx -lmpi -lopen-rte -lopen-pal -ldl -Wl,--export-dynamic -lnsl -lutil -lm -ldl
+
+echo "setup virtualenv"
+pip install virtualenv
+
+cd video_prediction_tools/env_setup
+chmod +x /builds/esde/machine-learning/ambs/video_prediction_tools/env_setup/create_env.sh
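+# create_env.sh is sourced (not executed) so that the environment it sets up (ambs_env) remains active for the steps below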
+source /builds/esde/machine-learning/ambs/video_prediction_tools/env_setup/create_env.sh ambs_env -l_nohpc -l_nocontainer
+
+echo "run postprocessing"
+# set parameters for postprocessing
+checkpoint_dir=/ambs/model/checkpoint_89
+results_dir=/ambs/results/
+lquick=1
+climate_file=/ambs/data_era5/T2monthly/climatology_t2m_1991-2020.nc
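+# the /ambs/* paths are expected to be provided by the CI runner (cf. the era5/checkpoints tags in .gitlab-ci.yml)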
+
+# select model
+model=convLSTM
+
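+# --lquick_evaluation enables the quick-evaluation mode (lquick) of Postprocess, skipping the persistence baseline and conditional quantiles; the lquick/model variables above are not passed to the script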
+python3 /builds/esde/machine-learning/ambs/video_prediction_tools/main_scripts/main_visualize_postprocess.py \
+    --checkpoint  ${checkpoint_dir} --test_mode \
+    --results_dir ${results_dir} --batch_size 4 \
+    --num_stochastic_samples 1 \
+    --lquick_evaluation \
+    --climatology_file ${climate_file}
+
+
diff --git a/cicd/training.sh b/cicd/training.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2b54434ebc44dcae769b0dae432f9cd2f2fcb48d
--- /dev/null
+++ b/cicd/training.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
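+# NOTE: placeholder only - the actual training steps are not wired into the CI pipeline yet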
+echo "set up training"
+echo "run training"
diff --git a/video_prediction_tools/env_setup/requirements_nocontainer.txt b/video_prediction_tools/env_setup/requirements_nocontainer.txt
index dc8475e048298372f4ddcfe137a39a1fb16766b9..184e50f31836fe1d749b17a796fb3e14fc2734c3 100755
--- a/video_prediction_tools/env_setup/requirements_nocontainer.txt
+++ b/video_prediction_tools/env_setup/requirements_nocontainer.txt
@@ -12,3 +12,4 @@ netcdf4==1.5.8
 normalization==0.4
 utils==1.0.1
 
+
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 0c9f8e434b9c1706b55a7ba6a8a99b3a7156e628..34ed8da4970c0f30b0357e8d72ca128e601c009d 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -21,6 +21,7 @@ import pickle
 import datetime as dt
 import json
 from typing import Union, List
+
 # own modules
 from normalization import Norm_data
 from netcdf_datahandling import get_era5_varatts
@@ -29,17 +30,41 @@ from metadata import MetaData as MetaData
 from main_train_models import TrainModel
 from data_preprocess.preprocess_data_step2 import *
 from model_modules.video_prediction import datasets, models, metrics
-from statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, calculate_cond_quantiles, Scores
-from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, create_geo_contour_plot
+from statistical_evaluation import (
+    perform_block_bootstrap_metric,
+    avg_metrics,
+    calculate_cond_quantiles,
+    Scores,
+)
+from postprocess_plotting import (
+    plot_avg_eval_metrics,
+    plot_cond_quantile,
+    create_geo_contour_plot,
+)
 import warnings
 
+
 class Postprocess(TrainModel):
-    def __init__(self, results_dir: str = None, checkpoint: str = None, data_mode: str = "test", batch_size: int = None,
-                 gpu_mem_frac: float = None, num_stochastic_samples: int = 1, stochastic_plot_id: int = 0,
-                 seed: int = None, channel: int = 0, run_mode: str = "deterministic", lquick: bool = None,
-                 frac_data: float = 1., eval_metrics: List = ("mse", "psnr", "ssim", "acc"), ltest=False,
-                 clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/"+
-                                  "climatology_t2m_1991-2020.nc", args=None):
+    def __init__(
+        self,
+        results_dir: str = None,
+        checkpoint: str = None,
+        data_mode: str = "test",
+        batch_size: int = None,
+        gpu_mem_frac: float = None,
+        num_stochastic_samples: int = 1,
+        stochastic_plot_id: int = 0,
+        seed: int = None,
+        channel: int = 0,
+        run_mode: str = "deterministic",
+        lquick: bool = None,
+        frac_data: float = 1.0,
+        eval_metrics: List = ("mse", "psnr", "ssim", "acc"),
+        ltest=False,
+        clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/"
+        + "climatology_t2m_1991-2020.nc",
+        args=None,
+    ):
         """
         Initialization of the class instance for postprocessing (generation of forecasts from trained model +
         basic evauation).
@@ -73,9 +98,11 @@ class Postprocess(TrainModel):
         self.stochastic_plot_id = stochastic_plot_id
         self.args = args
         self.checkpoint = checkpoint
-        if not os.path.isfile(self.checkpoint+".meta"): 
+        if not os.path.isfile(self.checkpoint + ".meta"):
             _ = check_dir(self.checkpoint)
-            self.checkpoint += "/"          # trick to handle checkpoint-directory and file simulataneously
+            self.checkpoint += (
+                "/"  # trick to handle checkpoint-directory and file simultaneously
+            )
         self.clim_path = clim_path
         self.run_mode = run_mode
         self.data_mode = data_mode
@@ -87,18 +114,31 @@ class Postprocess(TrainModel):
         # configuration of basic evaluation
         self.eval_metrics = eval_metrics
         self.nboots_block = 1000
-        self.block_length = 7 * 24  # this corresponds to a block length of 7 days in case of hourly forecasts
-        if ltest: self.block_length = 1
+        self.block_length = (
+            7 * 24
+        )  # this corresponds to a block length of 7 days in case of hourly forecasts
+        if ltest:
+            self.block_length = 1
         # initialize evrything to get an executable Postprocess instance
         if args is not None:
-            self.save_args_to_option_json()     # create options.json in results directory
-        self.copy_data_model_json()             # copy over JSON-files from model directory
+            self.save_args_to_option_json()  # create options.json in results directory
+        self.copy_data_model_json()  # copy over JSON-files from model directory
         # get some parameters related to model and dataset
-        self.datasplit_dict, self.model_hparams_dict, self.dataset, self.model, self.input_dir_tfr = self.load_jsons()
+        (
+            self.datasplit_dict,
+            self.model_hparams_dict,
+            self.dataset,
+            self.model,
+            self.input_dir_tfr,
+        ) = self.load_jsons()
         self.model_hparams_dict_load = self.get_model_hparams_dict()
         # set input paths and forecast product dictionary
         self.input_dir, self.input_dir_pkl = self.get_input_dirs()
-        self.fcst_products = {self.model: "mfcst"} if lquick else {"persistence": "pfcst", self.model: "mfcst"}
+        self.fcst_products = (
+            {self.model: "mfcst"}
+            if lquick
+            else {"persistence": "pfcst", self.model: "mfcst"}
+        )
         # correct number of stochastic samples if necessary
         self.check_num_stochastic_samples()
         # get metadata
@@ -112,9 +152,15 @@ class Postprocess(TrainModel):
         # setup test dataset and model
         self.test_dataset, self.num_samples_per_epoch = self.setup_dataset()
         if lquick and self.test_dataset.shuffled:
-            self.num_samples_per_epoch = Postprocess.reduce_samples(self.num_samples_per_epoch, frac_data)
+            self.num_samples_per_epoch = Postprocess.reduce_samples(
+                self.num_samples_per_epoch, frac_data
+            )
         # self.num_samples_per_epoch = 100              # reduced number of epoch samples -> useful for testing
-        self.sequence_length, self.context_frames, self.future_length = self.get_data_params()
+        (
+            self.sequence_length,
+            self.context_frames,
+            self.future_length,
+        ) = self.get_data_params()
         self.inputs, self.input_ts = self.make_test_dataset_iterator()
         self.data_clim = None
         if "acc" in eval_metrics:
@@ -134,7 +180,9 @@ class Postprocess(TrainModel):
         method = Postprocess.get_input_dirs.__name__
 
         if not hasattr(self, "input_dir_tfr"):
-            raise AttributeError("Attribute input_dir_tfr is still missing.".format(method))
+            raise AttributeError(
+                "%{0}: Attribute input_dir_tfr is still missing.".format(method)
+            )
 
         _ = check_dir(self.input_dir_tfr)
 
@@ -167,24 +215,38 @@ class Postprocess(TrainModel):
         model_dd_js = os.path.join(model_outdir, "data_split.json")
 
         if os.path.isfile(model_opt_js):
-            shutil.copy(model_opt_js, os.path.join(self.results_dir, "options_checkpoints.json"))
+            shutil.copy(
+                model_opt_js, os.path.join(self.results_dir, "options_checkpoints.json")
+            )
         else:
-            raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_opt_js))
+            raise FileNotFoundError(
+                "%{0}: The file {1} does not exist".format(method_name, model_opt_js)
+            )
 
         if os.path.isfile(model_ds_js):
-            shutil.copy(model_ds_js, os.path.join(self.results_dir, "dataset_hparams.json"))
+            shutil.copy(
+                model_ds_js, os.path.join(self.results_dir, "dataset_hparams.json")
+            )
         else:
-            raise FileNotFoundError("%{0}: the file {1} does not exist".format(method_name, model_ds_js))
+            raise FileNotFoundError(
+                "%{0}: the file {1} does not exist".format(method_name, model_ds_js)
+            )
 
         if os.path.isfile(model_hp_js):
-            shutil.copy(model_hp_js, os.path.join(self.results_dir, "model_hparams.json"))
+            shutil.copy(
+                model_hp_js, os.path.join(self.results_dir, "model_hparams.json")
+            )
         else:
-            raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_hp_js))
+            raise FileNotFoundError(
+                "%{0}: The file {1} does not exist".format(method_name, model_hp_js)
+            )
 
         if os.path.isfile(model_dd_js):
             shutil.copy(model_dd_js, os.path.join(self.results_dir, "data_split.json"))
         else:
-            raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_dd_js))
+            raise FileNotFoundError(
+                "%{0}: The file {1} does not exist".format(method_name, model_dd_js)
+            )
 
     def load_jsons(self):
         """
@@ -204,16 +266,25 @@ class Postprocess(TrainModel):
 
         # sanity checks on the JSON-files
         if not os.path.isfile(datasplit_dict):
-            raise FileNotFoundError("%{0}: The file data_split.json is missing in {1}".format(method_name,
-                                                                                             self.results_dir))
+            raise FileNotFoundError(
+                "%{0}: The file data_split.json is missing in {1}".format(
+                    method_name, self.results_dir
+                )
+            )
 
         if not os.path.isfile(model_hparams_dict):
-            raise FileNotFoundError("%{0}: The file model_hparams.json is missing in {1}".format(method_name,
-                                                                                                 self.results_dir))
+            raise FileNotFoundError(
+                "%{0}: The file model_hparams.json is missing in {1}".format(
+                    method_name, self.results_dir
+                )
+            )
 
         if not os.path.isfile(checkpoint_opt_dict):
-            raise FileNotFoundError("%{0}: The file options_checkpoints.json is missing in {1}"
-                                    .format(method_name, self.results_dir))
+            raise FileNotFoundError(
+                "%{0}: The file options_checkpoints.json is missing in {1}".format(
+                    method_name, self.results_dir
+                )
+            )
         # retrieve some data from options_checkpoints.json
         try:
             with open(checkpoint_opt_dict) as f:
@@ -222,8 +293,11 @@ class Postprocess(TrainModel):
                 model = options_checkpoint["model"]
                 input_dir_tfr = options_checkpoint["input_dir"]
         except Exception as err:
-            print("%{0}: Something went wrong when reading the checkpoint-file '{1}'".format(method_name,
-                                                                                             checkpoint_opt_dict))
+            print(
+                "%{0}: Something went wrong when reading the checkpoint-file '{1}'".format(
+                    method_name, checkpoint_opt_dict
+                )
+            )
             raise err
 
         return datasplit_dict, model_hparams_dict, dataset, model, input_dir_tfr
@@ -234,43 +308,55 @@ class Postprocess(TrainModel):
 
         # some sanity checks
         if self.input_dir is None:
-            raise AttributeError("%{0}: input_dir-attribute is still None".format(method_name))
+            raise AttributeError(
+                "%{0}: input_dir-attribute is still None".format(method_name)
+            )
 
         metadata_fl = os.path.join(self.input_dir, "metadata.json")
 
         if not os.path.isfile(metadata_fl):
-            raise FileNotFoundError("%{0}: Could not find metadata JSON-file under '{1}'".format(method_name,
-                                                                                                 self.input_dir))
+            raise FileNotFoundError(
+                "%{0}: Could not find metadata JSON-file under '{1}'".format(
+                    method_name, self.input_dir
+                )
+            )
 
         try:
             md_instance = MetaData(json_file=metadata_fl)
         except Exception as err:
-            print("%{0}: Something went wrong when getting metadata from file '{1}'".format(method_name, metadata_fl))
+            print(
+                "%{0}: Something went wrong when getting metadata from file '{1}'".format(
+                    method_name, metadata_fl
+                )
+            )
             raise err
 
         return md_instance
 
-    def load_climdata(self,data_clim_path="/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/climatology_t2m_1991-2020.nc",
-                            var="var167"):
+    def load_climdata(
+        self,
+        data_clim_path="/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/climatology_t2m_1991-2020.nc",
+        var="var167",
+    ):
         """
         params:data_cli_path : str, the full path to the climatology file
-        params:var           : str, the variable name 
-        
+        params:var           : str, the variable name
+
         """
         data = xr.open_dataset(data_clim_path)
         dt_clim = data[var]
 
-        clim_lon = dt_clim['lon'].data
-        clim_lat = dt_clim['lat'].data
-        
+        clim_lon = dt_clim["lon"].data
+        clim_lat = dt_clim["lat"].data
+
         meta_lon_loc = np.zeros((len(clim_lon)), dtype=bool)
         for i in range(len(clim_lon)):
-            if np.round(clim_lon[i],1) in self.lons.data:
+            if np.round(clim_lon[i], 1) in self.lons.data:
                 meta_lon_loc[i] = True
 
         meta_lat_loc = np.zeros((len(clim_lat)), dtype=bool)
         for i in range(len(clim_lat)):
-            if np.round(clim_lat[i],1) in self.lats.data:
+            if np.round(clim_lat[i], 1) in self.lats.data:
                 meta_lat_loc[i] = True
 
         # get the coordinates of the data after running CDO
@@ -279,25 +365,33 @@ class Postprocess(TrainModel):
         # modify it our needs
         coords_new = dict(coords)
         coords_new.pop("time")
-        coords_new["month"] = np.arange(1, 13) 
+        coords_new["month"] = np.arange(1, 13)
         coords_new["hour"] = np.arange(0, 24)
         # initialize a new data array with explicit dimensions for month and hour
-        data_clim_new = xr.DataArray(np.full((12, 24, nlat, nlon), np.nan), coords=coords_new,
-                                     dims=["month", "hour", "lat", "lon"])
+        data_clim_new = xr.DataArray(
+            np.full((12, 24, nlat, nlon), np.nan),
+            coords=coords_new,
+            dims=["month", "hour", "lat", "lon"],
+        )
         # do the reorganization
-        for month in np.arange(1, 13): 
-            data_clim_new.loc[dict(month=month)]=dt_clim.sel(time=dt_clim["time.month"]==month)
+        for month in np.arange(1, 13):
+            data_clim_new.loc[dict(month=month)] = dt_clim.sel(
+                time=dt_clim["time.month"] == month
+            )
+
+        self.data_clim = data_clim_new[dict(lon=meta_lon_loc, lat=meta_lat_loc)]
 
-        self.data_clim = data_clim_new[dict(lon=meta_lon_loc,lat=meta_lat_loc)]
-         
     def setup_dataset(self):
         """
         setup the test dataset instance
         :return test_dataset: the test dataset instance
         """
         VideoDataset = datasets.get_dataset_class(self.dataset)
-        test_dataset = VideoDataset(input_dir=self.input_dir_tfr, mode=self.data_mode,
-                                    datasplit_config=self.datasplit_dict)
+        test_dataset = VideoDataset(
+            input_dir=self.input_dir_tfr,
+            mode=self.data_mode,
+            datasplit_config=self.datasplit_dict,
+        )
         nsamples = test_dataset.num_examples_per_epoch()
 
         return test_dataset, nsamples
@@ -310,18 +404,25 @@ class Postprocess(TrainModel):
         method = Postprocess.get_data_params.__name__
 
         if not hasattr(self, "model_hparams_dict_load"):
-            raise AttributeError("%{0}: Attribute model_hparams_dict_load is still unset.".format(method))
+            raise AttributeError(
+                "%{0}: Attribute model_hparams_dict_load is still unset.".format(method)
+            )
 
         try:
             context_frames = self.model_hparams_dict_load["context_frames"]
             sequence_length = self.model_hparams_dict_load["sequence_length"]
         except Exception as err:
-            print("%{0}: Could not retrieve context_frames and sequence_length from model_hparams_dict_load-attribute"
-                  .format(method))
+            print(
+                "%{0}: Could not retrieve context_frames and sequence_length from model_hparams_dict_load-attribute".format(
+                    method
+                )
+            )
             raise err
         future_length = sequence_length - context_frames
         if future_length <= 0:
-            raise ValueError("Calculated future_length must be greater than zero.".format(method))
+            raise ValueError(
+                "%{0}: Calculated future_length must be greater than zero.".format(method)
+            )
 
         return sequence_length, context_frames, future_length
 
@@ -333,11 +434,15 @@ class Postprocess(TrainModel):
         method = Postprocess.set_stat_file.__name__
 
         if not hasattr(self, "input_dir"):
-            raise AttributeError("%{0}: Attribute input_dir is still unset".format(method))
+            raise AttributeError(
+                "%{0}: Attribute input_dir is still unset".format(method)
+            )
 
         stat_fl = os.path.join(self.input_dir, "statistics.json")
         if not os.path.isfile(stat_fl):
-            raise FileNotFoundError("%{0}: Cannot find statistics JSON-file '{1}'".format(method, stat_fl))
+            raise FileNotFoundError(
+                "%{0}: Cannot find statistics JSON-file '{1}'".format(method, stat_fl)
+            )
 
         return stat_fl
 
@@ -350,8 +455,10 @@ class Postprocess(TrainModel):
 
         if not hasattr(self, "model"):
             raise AttributeError("%{0}: Attribute model is still unset.".format(method))
-        cond_quantile_vars = ["{0}_{1}_fcst".format(self.vars_in[self.channel], self.model),
-                              "{0}_ref".format(self.vars_in[self.channel])]
+        cond_quantile_vars = [
+            "{0}_{1}_fcst".format(self.vars_in[self.channel], self.model),
+            "{0}_ref".format(self.vars_in[self.channel]),
+        ]
 
         return cond_quantile_vars
 
@@ -362,18 +469,23 @@ class Postprocess(TrainModel):
         method = Postprocess.make_test_dataset_iterator.__name__
 
         if not hasattr(self, "test_dataset"):
-            raise AttributeError("%{0}: Attribute test_dataset is still unset".format(method))
+            raise AttributeError(
+                "%{0}: Attribute test_dataset is still unset".format(method)
+            )
 
         if not hasattr(self, "batch_size"):
-            raise AttributeError("%{0}: Attribute batch_sie is still unset".format(method))
+            raise AttributeError(
+                "%{0}: Attribute batch_size is still unset".format(method)
+            )
 
         test_tf_dataset = self.test_dataset.make_dataset(self.batch_size)
         test_iterator = test_tf_dataset.make_one_shot_iterator()
         # The `Iterator.string_handle()` method returns a tensor that can be evaluated
         # and used to feed the `handle` placeholder.
         test_handle = test_iterator.string_handle()
-        dataset_iterator = tf.data.Iterator.from_string_handle(test_handle, test_tf_dataset.output_types,
-                                                               test_tf_dataset.output_shapes)
+        dataset_iterator = tf.data.Iterator.from_string_handle(
+            test_handle, test_tf_dataset.output_types, test_tf_dataset.output_shapes
+        )
         input_iter = dataset_iterator.get_next()
         ts_iter = input_iter["T_start"]
 
@@ -389,11 +501,15 @@ class Postprocess(TrainModel):
         if not hasattr(self, "model"):
             raise AttributeError("%{0}: Attribute model is still unset".format(method))
         if not hasattr(self, "num_stochastic_samples"):
-            raise AttributeError("%{0}: Attribute num_stochastic_samples is still unset".format(method))
+            raise AttributeError(
+                "%{0}: Attribute num_stochastic_samples is still unset".format(method)
+            )
 
         if np.any(self.model in ["convLSTM", "test_model", "mcnet"]):
             if self.num_stochastic_samples > 1:
-                print("Number of samples for deterministic model cannot be larger than 1. Higher values are ignored.")
+                print(
+                    "Number of samples for deterministic model cannot be larger than 1. Higher values are ignored."
+                )
             self.num_stochastic_samples = 1
 
     # the run-factory
@@ -416,101 +532,149 @@ class Postprocess(TrainModel):
         self.restore(self.sess, self.checkpoint)
         # Loop for samples
         self.sample_ind = 0
-        self.prst_metric_all = []  # store evaluation metrics of persistence forecast (shape [future_len])
-        self.fcst_metric_all = []  # store evaluation metric of stochastic forecasts (shape [nstoch, batch, future_len])
+        self.prst_metric_all = (
+            []
+        )  # store evaluation metrics of persistence forecast (shape [future_len])
+        self.fcst_metric_all = (
+            []
+        )  # store evaluation metric of stochastic forecasts (shape [nstoch, batch, future_len])
         while self.sample_ind < self.num_samples_per_epoch:
             if self.num_samples_per_epoch < self.sample_ind:
                 break
             else:
                 # run the inputs and plot each sequence images
-                self.input_results, self.input_images_denorm_all, self.t_starts = self.get_input_data_per_batch()
-
-            feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()}
+                (
+                    self.input_results,
+                    self.input_images_denorm_all,
+                    self.t_starts,
+                ) = self.get_input_data_per_batch()
+
+            feed_dict = {
+                input_ph: self.input_results[name]
+                for name, input_ph in self.inputs.items()
+            }
             gen_loss_stochastic_batch = []  # [stochastic_ind,future_length]
-            gen_images_stochastic = []  # [stochastic_ind,batch_size,seq_len,lat,lon,channels]
+            gen_images_stochastic = (
+                []
+            )  # [stochastic_ind,batch_size,seq_len,lat,lon,channels]
             # Loop for stochastics
             for stochastic_sample_ind in range(self.num_stochastic_samples):
                 print("stochastic_sample_ind:", stochastic_sample_ind)
                 # return [batchsize,seq_len,lat,lon,channel]
-                gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)
+                gen_images = self.sess.run(
+                    self.video_model.outputs["gen_images"], feed_dict=feed_dict
+                )
                 # The generate images seq_len should be sequence_len -1, since the last one is
                 # not used for comparing with groud truth
                 assert gen_images.shape[1] == self.sequence_length - 1
                 gen_images_per_batch = []
                 if stochastic_sample_ind == 0:
-                    persistent_images_per_batch = []  # [batch_size,seq_len,lat,lon,channel]
+                    persistent_images_per_batch = (
+                        []
+                    )  # [batch_size,seq_len,lat,lon,channel]
                     ts_batch = []
                 for i in range(self.batch_size):
                     # generate time stamps for sequences only once, since they are the same for all ensemble members
                     if stochastic_sample_ind == 0:
-                        self.ts = Postprocess.generate_seq_timestamps(self.t_starts[i], len_seq=self.sequence_length)
+                        self.ts = Postprocess.generate_seq_timestamps(
+                            self.t_starts[i], len_seq=self.sequence_length
+                        )
                         init_date_str = self.ts[0].strftime("%Y%m%d%H")
                         ts_batch.append(init_date_str)
                         # get persistence_images
-                        self.persistence_images, self.ts_persistence = Postprocess.get_persistence(self.ts,
-                                                                                                   self.input_dir_pkl)
+                        (
+                            self.persistence_images,
+                            self.ts_persistence,
+                        ) = Postprocess.get_persistence(self.ts, self.input_dir_pkl)
                         persistent_images_per_batch.append(self.persistence_images)
                         assert len(np.array(persistent_images_per_batch).shape) == 5
                         self.plot_persistence_images()
 
                     # Denormalized data for generate
                     gen_images_ = gen_images[i]
-                    self.gen_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl, gen_images_,
-                                                                                    self.vars_in)
+                    self.gen_images_denorm = Postprocess.denorm_images_all_channels(
+                        self.stat_fl, gen_images_, self.vars_in
+                    )
                     gen_images_per_batch.append(self.gen_images_denorm)
                     assert len(np.array(gen_images_per_batch).shape) == 5
                     # only plot when the first stochastic ind otherwise too many plots would be created
                     # only plot the stochastic results of user-defined ind
-                    self.plot_generate_images(stochastic_sample_ind, self.stochastic_plot_id)
+                    self.plot_generate_images(
+                        stochastic_sample_ind, self.stochastic_plot_id
+                    )
                 # calculate the persistnet error per batch
                 if stochastic_sample_ind == 0:
-                    persistent_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all,
-                                                                                       persistent_images_per_batch,
-                                                                                       self.future_length,
-                                                                                       self.context_frames,
-                                                                                       matric="mse", channel=0)
+                    persistent_loss_per_batch = Postprocess.calculate_metrics_by_batch(
+                        self.input_images_denorm_all,
+                        persistent_images_per_batch,
+                        self.future_length,
+                        self.context_frames,
+                        matric="mse",
+                        channel=0,
+                    )
                     self.prst_metric_all.append(persistent_loss_per_batch)
 
                 # calculate the gen_images_per_batch error
-                gen_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all,
-                                                                            gen_images_per_batch, self.future_length,
-                                                                            self.context_frames,
-                                                                            matric="mse", channel=0)
+                gen_loss_per_batch = Postprocess.calculate_metrics_by_batch(
+                    self.input_images_denorm_all,
+                    gen_images_per_batch,
+                    self.future_length,
+                    self.context_frames,
+                    matric="mse",
+                    channel=0,
+                )
                 gen_loss_stochastic_batch.append(
-                    gen_loss_per_batch)  # self.gen_images_stochastic[stochastic,future_length]
-                print("gen_images_per_batch shape:", np.array(gen_images_per_batch).shape)
+                    gen_loss_per_batch
+                )  # self.gen_images_stochastic[stochastic,future_length]
+                print(
+                    "gen_images_per_batch shape:", np.array(gen_images_per_batch).shape
+                )
                 gen_images_stochastic.append(
-                    gen_images_per_batch)  # [stochastic,batch_size, seq_len, lat, lon, channel]
+                    gen_images_per_batch
+                )  # [stochastic,batch_size, seq_len, lat, lon, channel]
 
                 # Switch the 0 and 1 position
                 print("before transpose:", np.array(gen_images_stochastic).shape)
-            gen_images_stochastic = np.transpose(np.array(gen_images_stochastic), (
-                1, 0, 2, 3, 4, 5))  # [batch_size, stochastic, seq_len, lat, lon, chanel]
+            gen_images_stochastic = np.transpose(
+                np.array(gen_images_stochastic), (1, 0, 2, 3, 4, 5)
+            )  # [batch_size, stochastic, seq_len, lat, lon, channel]
             Postprocess.check_gen_images_stochastic_shape(gen_images_stochastic)
             assert len(gen_images_stochastic.shape) == 6
-            assert np.array(gen_images_stochastic).shape[1] == self.num_stochastic_samples
+            assert (
+                np.array(gen_images_stochastic).shape[1] == self.num_stochastic_samples
+            )
 
             self.fcst_metric_all.append(
-                gen_loss_stochastic_batch)  # [samples/batch_size,stochastic,future_length]
+                gen_loss_stochastic_batch
+            )  # [samples/batch_size,stochastic,future_length]
             # save input and stochastic generate images to netcdf file
             # For each prediction (either deterministic or ensemble) we create one netCDF file.
             for batch_id in range(self.batch_size):
-                self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id],
-                                                                   persistent_images_per_batch[batch_id],
-                                                                   np.array(gen_images_stochastic)[batch_id],
-                                                                   fl_name="vfp_date_{}_sample_ind_{}.nc"
-                                                                   .format(ts_batch[batch_id],
-                                                                           self.sample_ind + batch_id))
+                self.save_to_netcdf_for_stochastic_generate_images(
+                    self.input_images_denorm_all[batch_id],
+                    persistent_images_per_batch[batch_id],
+                    np.array(gen_images_stochastic)[batch_id],
+                    fl_name="vfp_date_{}_sample_ind_{}.nc".format(
+                        ts_batch[batch_id], self.sample_ind + batch_id
+                    ),
+                )
 
             self.sample_ind += self.batch_size
 
-        self.persistent_loss_all_batches = np.mean(np.array(self.persistent_loss_all_batches), axis=0)
-        self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0)
+        self.persistent_loss_all_batches = np.mean(
+            np.array(self.persistent_loss_all_batches), axis=0
+        )
+        self.stochastic_loss_all_batches = np.mean(
+            np.array(self.stochastic_loss_all_batches), axis=0
+        )
         assert len(np.array(self.persistent_loss_all_batches).shape) == 1
         assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length
 
         assert len(np.array(self.stochastic_loss_all_batches).shape) == 2
-        assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples
+        assert (
+            np.array(self.stochastic_loss_all_batches).shape[0]
+            == self.num_stochastic_samples
+        )
 
     def run_deterministic(self):
         """
@@ -526,57 +690,107 @@ class Postprocess(TrainModel):
         sample_ind = 0
         nsamples = self.num_samples_per_epoch
         # initialize xarray datasets
-        eval_metric_ds = Postprocess.init_metric_ds(self.fcst_products, self.eval_metrics, self.vars_in[self.channel],
-                                                    nsamples, self.future_length)
+        eval_metric_ds = Postprocess.init_metric_ds(
+            self.fcst_products,
+            self.eval_metrics,
+            self.vars_in[self.channel],
+            nsamples,
+            self.future_length,
+        )
         cond_quantiple_ds = None
 
         while sample_ind < nsamples:
             # get normalized and denormalized input data
-            input_results, input_images_denorm, t_starts = self.get_input_data_per_batch(self.inputs)
+            (
+                input_results,
+                input_images_denorm,
+                t_starts,
+            ) = self.get_input_data_per_batch(self.inputs)
             # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel]
-            print("%{0}: Start generating {1:d} predictions at current sample index {2:d}".format(method, self.batch_size,
-                                                                                                  sample_ind))
-            feed_dict = {input_ph: input_results[name] for name, input_ph in self.inputs.items()}
-            gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)
+            print(
+                "%{0}: Start generating {1:d} predictions at current sample index {2:d}".format(
+                    method, self.batch_size, sample_ind
+                )
+            )
+            feed_dict = {
+                input_ph: input_results[name] for name, input_ph in self.inputs.items()
+            }
+            gen_images = self.sess.run(
+                self.video_model.outputs["gen_images"], feed_dict=feed_dict
+            )
 
             # sanity check on length of forecast sequence
-            assert gen_images.shape[1] == self.sequence_length - 1, \
-                "%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(method)
+            assert (
+                gen_images.shape[1] == self.sequence_length - 1
+            ), "%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(
+                method
+            )
             # denormalize forecast sequence (self.norm_cls is already set in get_input_data_per_batch-method)
-            gen_images_denorm = self.denorm_images_all_channels(gen_images, self.vars_in, self.norm_cls,
-                                                                norm_method="minmax")
+            gen_images_denorm = self.denorm_images_all_channels(
+                gen_images, self.vars_in, self.norm_cls, norm_method="minmax"
+            )
             # store data into datset & get number of samples (may differ from batch_size at the end of the test dataset)
             times_0, init_times = self.get_init_time(t_starts)
-            batch_ds = self.create_dataset(input_images_denorm, gen_images_denorm, init_times)
+            batch_ds = self.create_dataset(
+                input_images_denorm, gen_images_denorm, init_times
+            )
             nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind)
             batch_ds = batch_ds.isel(init_time=slice(0, nbs))
 
             # run over mini-batch only if quick evaluation is NOT active
             for i in np.arange(0 if self.lquick else nbs):
-                print("%{0}: Process mini-batch sample {1:d}/{2:d}".format(method, i+1, nbs))
+                print(
+                    "%{0}: Process mini-batch sample {1:d}/{2:d}".format(
+                        method, i + 1, nbs
+                    )
+                )
                 # work-around to make use of get_persistence_forecast_per_sample-method
-                times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime()
+                times_seq = (
+                    pd.date_range(
+                        times_0[i], periods=int(self.sequence_length), freq="h"
+                    )
+                ).to_pydatetime()
                 # get persistence forecast for sequences at hand and write to dataset
-                persistence_seq, _ = Postprocess.get_persistence(times_seq, self.input_dir_pkl)
+                persistence_seq, _ = Postprocess.get_persistence(
+                    times_seq, self.input_dir_pkl
+                )
                 for ivar, var in enumerate(self.vars_in):
-                    batch_ds["{0}_persistence_fcst".format(var)].loc[dict(init_time=init_times[i])] = \
-                            persistence_seq[self.context_frames-1:, :, :, ivar]
+                    batch_ds["{0}_persistence_fcst".format(var)].loc[
+                        dict(init_time=init_times[i])
+                    ] = persistence_seq[self.context_frames - 1 :, :, :, ivar]
 
                 # save sequences to netcdf-file and track initial time
-                nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc"
-                                        .format(pd.to_datetime(init_times[i]).strftime("%Y%m%d%H"), sample_ind + i))
-                
+                nc_fname = os.path.join(
+                    self.results_dir,
+                    "vfp_date_{0}_sample_ind_{1:d}.nc".format(
+                        pd.to_datetime(init_times[i]).strftime("%Y%m%d%H"),
+                        sample_ind + i,
+                    ),
+                )
+
                 if os.path.exists(nc_fname):
-                    print("%{0}: The file '{1}' already exists and is therefore skipped".format(method, nc_fname))
+                    print(
+                        "%{0}: The file '{1}' already exists and is therefore skipped".format(
+                            method, nc_fname
+                        )
+                    )
                 else:
                     self.save_ds_to_netcdf(batch_ds.isel(init_time=i), nc_fname)
                 # end of batch-loop
             # write evaluation metric to corresponding dataset and sa
-            eval_metric_ds = self.populate_eval_metric_ds(eval_metric_ds, batch_ds, sample_ind,
-                                                          self.vars_in[self.channel])
-            if not self.lquick:             # conditional quantiles are not evaluated for quick evaluation
-                cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars,
-                                                          "init_time", dtype=np.float16)
+            eval_metric_ds = self.populate_eval_metric_ds(
+                eval_metric_ds, batch_ds, sample_ind, self.vars_in[self.channel]
+            )
+            if (
+                not self.lquick
+            ):  # conditional quantiles are not evaluated for quick evaluation
+                cond_quantiple_ds = Postprocess.append_ds(
+                    batch_ds,
+                    cond_quantiple_ds,
+                    self.cond_quantile_vars,
+                    "init_time",
+                    dtype=np.float16,
+                )
             # ... and increment sample_ind
             sample_ind += self.batch_size
             # end of while-loop for samples
@@ -584,7 +798,7 @@ class Postprocess(TrainModel):
         self.eval_metrics_ds = eval_metric_ds
         self.cond_quantiple_ds = cond_quantiple_ds
         self.sess.close()
-             
+
     # all methods of the run factory
     def init_session(self):
         """
@@ -617,15 +831,23 @@ class Postprocess(TrainModel):
         t_starts = input_results["T_start"]
         if self.norm_cls is None:
             if self.stat_fl is None:
-                raise AttributeError("%{0}: Attribute stat_fl is not initialized yet.".format(method))
-            self.norm_cls = Postprocess.get_norm(self.vars_in, self.stat_fl, norm_method)
+                raise AttributeError(
+                    "%{0}: Attribute stat_fl is not initialized yet.".format(method)
+                )
+            self.norm_cls = Postprocess.get_norm(
+                self.vars_in, self.stat_fl, norm_method
+            )
 
         # sanity check on input sequence
-        assert np.ndim(input_images) == 5, "%{0}: Input sequence of mini-batch does not have five dimensions."\
-                                           .format(method)
+        assert (
+            np.ndim(input_images) == 5
+        ), "%{0}: Input sequence of mini-batch does not have five dimensions.".format(
+            method
+        )
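+        # the five dimensions correspond to (batch, sequence, lat, lon, channel)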
 
-        input_images_denorm = Postprocess.denorm_images_all_channels(input_images, self.vars_in, self.norm_cls,
-                                                                     norm_method=norm_method)
+        input_images_denorm = Postprocess.denorm_images_all_channels(
+            input_images, self.vars_in, self.norm_cls, norm_method=norm_method
+        )
 
         return input_results, input_images_denorm, t_starts
 
@@ -639,15 +861,24 @@ class Postprocess(TrainModel):
 
         t_starts = np.squeeze(np.asarray(t_starts))
         if not np.ndim(t_starts) == 1:
-            raise ValueError("%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H"
-                             .format(method))
+            raise ValueError(
+                "%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H".format(
+                    method
+                )
+            )
         for i, t_start in enumerate(t_starts):
             try:
-                seq_ts = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H"), periods=self.context_frames,
-                                       freq="h")
+                seq_ts = pd.date_range(
+                    dt.datetime.strptime(str(t_start), "%Y%m%d%H"),
+                    periods=self.context_frames,
+                    freq="h",
+                )
             except Exception as err:
-                print("%{0}: Could not convert {1} to datetime object. Ensure that the date-string format is 'Y%m%d%H'".
-                      format(method, str(t_start)))
+                print(
+                    "%{0}: Could not convert {1} to datetime object. Ensure that the date-string format is 'Y%m%d%H'".format(
+                        method, str(t_start)
+                    )
+                )
                 raise err
             if i == 0:
                 ts_all = np.expand_dims(seq_ts, axis=0)
@@ -672,7 +903,9 @@ class Postprocess(TrainModel):
 
         # dictionary of implemented evaluation metrics
         dims = ["lat", "lon"]
-        eval_metrics_func = [Scores(metric, dims).score_func for metric in self.eval_metrics]
+        eval_metrics_func = [
+            Scores(metric, dims).score_func for metric in self.eval_metrics
+        ]
         varname_ref = "{0}_ref".format(varname)
         # reset init-time coordinate of metric_ds in place and get indices for slicing
         ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch)
@@ -685,12 +918,14 @@ class Postprocess(TrainModel):
                 metric_name = "{0}_{1}_{2}".format(varname, fcst_prod, eval_metric)
                 varname_fcst = "{0}_{1}_fcst".format(varname, fcst_prod)
                 dict_ind = dict(init_time=data_ds["init_time"])
-                metric_ds[metric_name].loc[dict_ind] = eval_metrics_func[imetric](data_fcst=data_ds[varname_fcst],
-                                                                                  data_ref=data_ds[varname_ref],
-                                                                                  data_clim=self.data_clim)
+                metric_ds[metric_name].loc[dict_ind] = eval_metrics_func[imetric](
+                    data_fcst=data_ds[varname_fcst],
+                    data_ref=data_ds[varname_ref],
+                    data_clim=self.data_clim,
+                )
             # end of metric-loop
         # end of forecast product-loop
-        
+
         return metric_ds
 
     def add_ensemble_dim(self):
@@ -698,8 +933,12 @@ class Postprocess(TrainModel):
         Expands dimensions of loss-arrays by dummy ensemble-dimension (used for deterministic forecasts only)
         :return:
         """
-        self.stochastic_loss_all_batches = np.expand_dims(self.fcst_mse_avg_batches, axis=0)  # [1,future_lenght]
-        self.stochastic_loss_all_batches_psnr = np.expand_dims(self.fcst_psnr_avg_batches, axis=0)  # [1,future_lenght]
+        self.stochastic_loss_all_batches = np.expand_dims(
+            self.fcst_mse_avg_batches, axis=0
+        )  # [1, future_length]
+        self.stochastic_loss_all_batches_psnr = np.expand_dims(
+            self.fcst_psnr_avg_batches, axis=0
+        )  # [1, future_length]
 
     def create_dataset(self, input_seq, fcst_seq, ts_ini):
         """
@@ -714,58 +953,118 @@ class Postprocess(TrainModel):
         method = Postprocess.create_dataset.__name__
 
         # auxiliary variables for temporal dimensions
-        seq_hours = np.arange(self.sequence_length) - (self.context_frames-1)
+        seq_hours = np.arange(self.sequence_length) - (self.context_frames - 1)
         # some sanity checks
-        assert np.shape(ts_ini)[0] == self.batch_size,\
-            "%{0}: Inconsistent number of sequence start times ({1:d}) and batch size ({2:d})"\
-            .format(method, np.shape(ts_ini)[0], self.batch_size)
+        assert (
+            np.shape(ts_ini)[0] == self.batch_size
+        ), "%{0}: Inconsistent number of sequence start times ({1:d}) and batch size ({2:d})".format(
+            method, np.shape(ts_ini)[0], self.batch_size
+        )
 
         # turn input and forecast sequences to Data Arrays to ease indexing
         try:
-            input_seq = xr.DataArray(input_seq, coords={"init_time": ts_ini, "fcst_hour": seq_hours,
-                                                        "lat": self.lats, "lon": self.lons, "varname": self.vars_in},
-                                     dims=["init_time", "fcst_hour", "lat", "lon", "varname"])
+            input_seq = xr.DataArray(
+                input_seq,
+                coords={
+                    "init_time": ts_ini,
+                    "fcst_hour": seq_hours,
+                    "lat": self.lats,
+                    "lon": self.lons,
+                    "varname": self.vars_in,
+                },
+                dims=["init_time", "fcst_hour", "lat", "lon", "varname"],
+            )
         except Exception as err:
-            print("%{0}: Could not create Data Array for input sequence.".format(method))
+            print(
+                "%{0}: Could not create Data Array for input sequence.".format(method)
+            )
             raise err
 
         try:
-            fcst_seq = xr.DataArray(fcst_seq, coords={"init_time": ts_ini, "fcst_hour": seq_hours[1::],
-                                                      "lat": self.lats, "lon": self.lons, "varname": self.vars_in},
-                                    dims=["init_time", "fcst_hour", "lat", "lon", "varname"])
+            fcst_seq = xr.DataArray(
+                fcst_seq,
+                coords={
+                    "init_time": ts_ini,
+                    "fcst_hour": seq_hours[1::],
+                    "lat": self.lats,
+                    "lon": self.lons,
+                    "varname": self.vars_in,
+                },
+                dims=["init_time", "fcst_hour", "lat", "lon", "varname"],
+            )
         except Exception as err:
-            print("%{0}: Could not create Data Array for forecast sequence.".format(method))
+            print(
+                "%{0}: Could not create Data Array for forecast sequence.".format(
+                    method
+                )
+            )
             raise err
 
         # Now create the dataset where the input sequence is splitted into input that served for creating the
         # forecast and into the the reference sequences (which can be compared to the forecast)
         # as where the persistence forecast is containing NaNs (must be generated later)
-        data_in_dict = dict([("{0}_in".format(var), input_seq.isel(fcst_hour=slice(None, self.context_frames),
-                                                                   varname=ivar)
-                                                             .rename({"fcst_hour": "in_hour"})
-                                                             .reset_coords(names="varname", drop=True))
-                             for ivar, var in enumerate(self.vars_in)])
+        data_in_dict = dict(
+            [
+                (
+                    "{0}_in".format(var),
+                    input_seq.isel(
+                        fcst_hour=slice(None, self.context_frames), varname=ivar
+                    )
+                    .rename({"fcst_hour": "in_hour"})
+                    .reset_coords(names="varname", drop=True),
+                )
+                for ivar, var in enumerate(self.vars_in)
+            ]
+        )
 
         # get shape of forecast data (one variable) -> required to initialize persistence forecast data
-        shape_fcst = np.shape(fcst_seq.isel(fcst_hour=slice(self.context_frames-1, None), varname=0)
-                                      .reset_coords(names="varname", drop=True))
-        data_ref_dict = dict([("{0}_ref".format(var), input_seq.isel(fcst_hour=slice(self.context_frames, None),
-                                                                     varname=ivar)
-                                                               .reset_coords(names="varname", drop=True))
-                              for ivar, var in enumerate(self.vars_in)])
-
-        data_mfcst_dict = dict([("{0}_{1}_fcst".format(var, self.model),
-                                 fcst_seq.isel(fcst_hour=slice(self.context_frames-1, None), varname=ivar)
-                                         .reset_coords(names="varname", drop=True))
-                                for ivar, var in enumerate(self.vars_in)])
+        shape_fcst = np.shape(
+            fcst_seq.isel(
+                fcst_hour=slice(self.context_frames - 1, None), varname=0
+            ).reset_coords(names="varname", drop=True)
+        )
+        data_ref_dict = dict(
+            [
+                (
+                    "{0}_ref".format(var),
+                    input_seq.isel(
+                        fcst_hour=slice(self.context_frames, None), varname=ivar
+                    ).reset_coords(names="varname", drop=True),
+                )
+                for ivar, var in enumerate(self.vars_in)
+            ]
+        )
+
+        data_mfcst_dict = dict(
+            [
+                (
+                    "{0}_{1}_fcst".format(var, self.model),
+                    fcst_seq.isel(
+                        fcst_hour=slice(self.context_frames - 1, None), varname=ivar
+                    ).reset_coords(names="varname", drop=True),
+                )
+                for ivar, var in enumerate(self.vars_in)
+            ]
+        )
 
         # fill persistence forecast variables with dummy data (to be populated later)
-        data_pfcst_dict = dict([("{0}_persistence_fcst".format(var), (["init_time", "fcst_hour", "lat", "lon"],
-                                                                      np.full(shape_fcst, np.nan)))
-                                for ivar, var in enumerate(self.vars_in)])
+        data_pfcst_dict = dict(
+            [
+                (
+                    "{0}_persistence_fcst".format(var),
+                    (
+                        ["init_time", "fcst_hour", "lat", "lon"],
+                        np.full(shape_fcst, np.nan),
+                    ),
+                )
+                for ivar, var in enumerate(self.vars_in)
+            ]
+        )
 
         # create the dataset
-        data_ds = xr.Dataset({**data_in_dict, **data_ref_dict, **data_mfcst_dict, **data_pfcst_dict})
+        data_ds = xr.Dataset(
+            {**data_in_dict, **data_ref_dict, **data_mfcst_dict, **data_pfcst_dict}
+        )
 
         return data_ds
 
@@ -777,11 +1076,16 @@ class Postprocess(TrainModel):
         method = Postprocess.handle_eval_metrics.__name__
 
         if self.eval_metrics_ds is None:
-            raise AttributeError("%{0}: Attribute with dataset of evaluation metrics is still None.".format(method))
+            raise AttributeError(
+                "%{0}: Attribute with dataset of evaluation metrics is still None.".format(
+                    method
+                )
+            )
 
         # perform bootstrapping on metric dataset
-        eval_metric_boot_ds = perform_block_bootstrap_metric(self.eval_metrics_ds, "init_time", self.block_length,
-                                                             self.nboots_block)
+        eval_metric_boot_ds = perform_block_bootstrap_metric(
+            self.eval_metrics_ds, "init_time", self.block_length, self.nboots_block
+        )
         # ... and merge into existing metric dataset
         self.eval_metrics_ds = xr.merge([self.eval_metrics_ds, eval_metric_boot_ds])
 
@@ -795,8 +1099,13 @@ class Postprocess(TrainModel):
         Postprocess.save_ds_to_netcdf(self.eval_metrics_ds, nc_fname)
 
         # also save averaged metrics to JSON-file and plot it for diagnosis
-        _ = plot_avg_eval_metrics(self.eval_metrics_ds, self.eval_metrics, self.fcst_products,
-                                  self.vars_in[self.channel], self.results_dir)
+        _ = plot_avg_eval_metrics(
+            self.eval_metrics_ds,
+            self.eval_metrics,
+            self.fcst_products,
+            self.vars_in[self.channel],
+            self.results_dir,
+        )
 
     def plot_example_forecasts(self, metric="mse", channel=0):
         """
@@ -811,20 +1120,30 @@ class Postprocess(TrainModel):
 
         metric_name = "{0}_{1}_{2}".format(self.vars_in[channel], self.model, metric)
         if not metric_name in self.eval_metrics_ds:
-            raise ValueError("%{0}: Cannot find requested evaluation metric '{1}'".format(method, metric_name) +
-                             " onto which selection of plotted forecast is done.")
+            raise ValueError(
+                "%{0}: Cannot find requested evaluation metric '{1}'".format(
+                    method, metric_name
+                )
+                + " onto which selection of plotted forecast is done."
+            )
         # average metric of interest and obtain quantiles incl. indices
         metric_mean = self.eval_metrics_ds[metric_name].mean(dim="fcst_hour")
-        quantiles = np.arange(0., 1.01, .1)
+        quantiles = np.arange(0.0, 1.01, 0.1)
         quantiles_val = metric_mean.quantile(quantiles, interpolation="nearest")
         quantiles_inds = self.get_matching_indices(metric_mean.values, quantiles_val)
 
         for i, ifcst in enumerate(quantiles_inds):
             date_init = pd.to_datetime(metric_mean.coords["init_time"][ifcst].data)
-            nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc"
-                                    .format(date_init.strftime("%Y%m%d%H"), ifcst))
+            nc_fname = os.path.join(
+                self.results_dir,
+                "vfp_date_{0}_sample_ind_{1:d}.nc".format(
+                    date_init.strftime("%Y%m%d%H"), ifcst
+                ),
+            )
             if not os.path.isfile(nc_fname):
-                raise FileNotFoundError("%{0}: Could not find requested file '{1}'".format(method, nc_fname))
+                raise FileNotFoundError(
+                    "%{0}: Could not find requested file '{1}'".format(method, nc_fname)
+                )
             else:
                 # get the data
                 varname = self.vars_in[channel]
@@ -834,9 +1153,15 @@ class Postprocess(TrainModel):
 
                 data_diff = data_fcst - data_ref
                 # name of plot
-                plt_fname_base = os.path.join(self.output_dir, "forecast_{0}_{1}_{2}_{3:d}percentile.png"
-                                              .format(varname, date_init.strftime("%Y%m%dT%H00"), metric,
-                                                      int(quantiles[i] * 100.)))
+                plt_fname_base = os.path.join(
+                    self.output_dir,
+                    "forecast_{0}_{1}_{2}_{3:d}percentile.png".format(
+                        varname,
+                        date_init.strftime("%Y%m%dT%H00"),
+                        metric,
+                        int(quantiles[i] * 100.0),
+                    ),
+                )
 
                 create_geo_contour_plot(data_fcst, data_diff, varname, plt_fname_base)
 
@@ -849,29 +1174,43 @@ class Postprocess(TrainModel):
         var_fcst = self.cond_quantile_vars[0]
         var_ref = self.cond_quantile_vars[1]
 
-        data_fcst = get_era5_varatts(self.cond_quantiple_ds[var_fcst], self.cond_quantiple_ds[var_fcst].name)
-        data_ref = get_era5_varatts(self.cond_quantiple_ds[var_ref], self.cond_quantiple_ds[var_ref].name)
+        data_fcst = get_era5_varatts(
+            self.cond_quantiple_ds[var_fcst], self.cond_quantiple_ds[var_fcst].name
+        )
+        data_ref = get_era5_varatts(
+            self.cond_quantiple_ds[var_ref], self.cond_quantiple_ds[var_ref].name
+        )
 
         # create plots
         fhhs = data_fcst.coords["fcst_hour"]
         for hh in fhhs:
             # calibration refinement factorization
-            plt_fname_cf = os.path.join(self.results_dir, "cond_quantile_{0}_{1}_fh{2:0d}_calibration_refinement.png"
-                                        .format(self.vars_in[self.channel], self.model, int(hh)))
-
-            quantile_panel_cf, cond_variable_cf = calculate_cond_quantiles(data_fcst.sel(fcst_hour=hh),
-                                                                           data_ref.sel(fcst_hour=hh),
-                                                                           factorization="calibration_refinement",
-                                                                           quantiles=(0.05, 0.5, 0.95))
+            plt_fname_cf = os.path.join(
+                self.results_dir,
+                "cond_quantile_{0}_{1}_fh{2:0d}_calibration_refinement.png".format(
+                    self.vars_in[self.channel], self.model, int(hh)
+                ),
+            )
+
+            quantile_panel_cf, cond_variable_cf = calculate_cond_quantiles(
+                data_fcst.sel(fcst_hour=hh),
+                data_ref.sel(fcst_hour=hh),
+                factorization="calibration_refinement",
+                quantiles=(0.05, 0.5, 0.95),
+            )
 
             plot_cond_quantile(quantile_panel_cf, cond_variable_cf, plt_fname_cf)
 
             # likelihood-base rate factorization
-            plt_fname_lbr = plt_fname_cf.replace("calibration_refinement", "likelihood-base_rate")
-            quantile_panel_lbr, cond_variable_lbr = calculate_cond_quantiles(data_fcst.sel(fcst_hour=hh),
-                                                                             data_ref.sel(fcst_hour=hh),
-                                                                             factorization="likelihood-base_rate",
-                                                                             quantiles=(0.05, 0.5, 0.95))
+            plt_fname_lbr = plt_fname_cf.replace(
+                "calibration_refinement", "likelihood-base_rate"
+            )
+            quantile_panel_lbr, cond_variable_lbr = calculate_cond_quantiles(
+                data_fcst.sel(fcst_hour=hh),
+                data_ref.sel(fcst_hour=hh),
+                factorization="likelihood-base_rate",
+                quantiles=(0.05, 0.5, 0.95),
+            )
 
             plot_cond_quantile(quantile_panel_lbr, cond_variable_lbr, plt_fname_lbr)
 
@@ -885,12 +1224,20 @@ class Postprocess(TrainModel):
         """
         method = Postprocess.reduce_samples.__name__
 
-        if frac_data <= 0. or frac_data >= 1.:
-            print("%{0}: frac_data is not within [0..1] and is therefore ignored.".format(method))
+        if frac_data <= 0.0 or frac_data >= 1.0:
+            print(
+                "%{0}: frac_data is not within [0..1] and is therefore ignored.".format(
+                    method
+                )
+            )
             return nsamples
         else:
-            nsamples_new = int(np.ceil(nsamples*frac_data))
-            print("%{0}: Sample size is reduced from {1:d} to {2:d}".format(method, int(nsamples), nsamples_new))
+            nsamples_new = int(np.ceil(nsamples * frac_data))
+            print(
+                "%{0}: Sample size is reduced from {1:d} to {2:d}".format(
+                    method, int(nsamples), nsamples_new
+                )
+            )
             return nsamples_new
 
     @staticmethod
@@ -905,7 +1252,11 @@ class Postprocess(TrainModel):
         method = Postprocess.clean_obj_attribute.__name__
 
         if not hasattr(obj, attr_name):
-            print("%{0}: Class attribute '{1}' does not exist. Nothing to do...".format(method, attr_name))
+            print(
+                "%{0}: Class attribute '{1}' does not exist. Nothing to do...".format(
+                    method, attr_name
+                )
+            )
         else:
             if lremove:
                 delattr(obj, attr_name)
@@ -927,7 +1278,9 @@ class Postprocess(TrainModel):
         method = Postprocess.get_norm.__name__
 
         if not isinstance(varnames, list):
-            raise ValueError("%{0}: varnames must be a list of variable names.".format(method))
+            raise ValueError(
+                "%{0}: varnames must be a list of variable names.".format(method)
+            )
 
         norm_cls = Norm_data(varnames)
         try:
@@ -935,12 +1288,18 @@ class Postprocess(TrainModel):
                 norm_cls.check_and_set_norm(json.load(js_file), norm_method)
             norm_cls = norm_cls
         except Exception as err:
-            print("%{0}: Could not handle statistics json-file '{1}'.".format(method, stat_fl))
+            print(
+                "%{0}: Could not handle statistics json-file '{1}'.".format(
+                    method, stat_fl
+                )
+            )
             raise err
         return norm_cls
 
     @staticmethod
-    def denorm_images_all_channels(image_sequence, varnames, norm, norm_method="minmax"):
+    def denorm_images_all_channels(
+        image_sequence, varnames, norm, norm_method="minmax"
+    ):
         """
         Denormalize data of all image channels
         :param image_sequence: list/array [batch, seq, lat, lon, channel] of images
@@ -955,15 +1314,23 @@ class Postprocess(TrainModel):
         image_sequence = np.array(image_sequence)
         # sanity checks
         if not isinstance(norm, Norm_data):
-            raise ValueError("%{0}: norm must be a normalization instance.".format(method))
+            raise ValueError(
+                "%{0}: norm must be a normalization instance.".format(method)
+            )
 
         if nvars != np.shape(image_sequence)[-1]:
-            raise ValueError("%{0}: Number of passed variable names ({1:d}) does not match number of channels ({2:d})"
-                             .format(method, nvars, np.shape(image_sequence)[-1]))
-
-        input_images_all_channles_denorm = [Postprocess.denorm_images(image_sequence, norm, {varname: c},
-                                                                      norm_method=norm_method)
-                                            for c, varname in enumerate(varnames)]
+            raise ValueError(
+                "%{0}: Number of passed variable names ({1:d}) does not match number of channels ({2:d})".format(
+                    method, nvars, np.shape(image_sequence)[-1]
+                )
+            )
+
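+        # denormalize each channel separately, then stack the results back along the channel axis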
+        input_images_all_channles_denorm = [
+            Postprocess.denorm_images(
+                image_sequence, norm, {varname: c}, norm_method=norm_method
+            )
+            for c, varname in enumerate(varnames)
+        ]
 
         input_images_denorm = np.stack(input_images_all_channles_denorm, axis=-1)
         return input_images_denorm
@@ -984,16 +1351,26 @@ class Postprocess(TrainModel):
             raise ValueError("%{0}: var_dict is not a dictionary.".format(method))
         else:
             if len(var_dict.keys()) > 1:
-                raise ValueError("%{0}: var_dict must contain one key only.".format(method))
+                raise ValueError(
+                    "%{0}: var_dict must contain one key only.".format(method)
+                )
             varname, channel = *var_dict.keys(), *var_dict.values()
 
         if not isinstance(norm, Norm_data):
-            raise ValueError("%{0}: norm must be a normalization instance.".format(method))
+            raise ValueError(
+                "%{0}: norm must be a normalization instance.".format(method)
+            )
 
         try:
-            input_images_denorm = norm.denorm_var(input_images[..., channel], varname, norm_method)
+            input_images_denorm = norm.denorm_var(
+                input_images[..., channel], varname, norm_method
+            )
         except Exception as err:
-            print("%{0}: Something went wrong when denormalizing image sequence. Inspect error-message!".format(method))
+            print(
+                "%{0}: Something went wrong when denormalizing image sequence. Inspect error-message!".format(
+                    method
+                )
+            )
             raise err
 
         return input_images_denorm
@@ -1024,7 +1401,9 @@ class Postprocess(TrainModel):
         """
         ts_persistence = []
         year_origin = ts[0].year
-        for t in range(len(ts)):  # Scarlet: this certainly can be made nicer with list comprehension
+        for t in range(
+            len(ts)
+        ):  # Scarlet: this certainly can be made nicer with list comprehension
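+            # shift every timestamp back by one day: the persistence forecast re-uses the previous day's fields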
             ts_temp = ts[t] - dt.timedelta(days=1)
             ts_persistence.append(ts_temp)
         t_persistence_start = ts_persistence[0]
@@ -1035,13 +1414,29 @@ class Postprocess(TrainModel):
         # only one pickle file is needed (all hours during the same month)
         if month_start == month_end:
             # Open files to search for the indizes of the corresponding time
-            time_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T'))
+            time_pickle = list(
+                Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_start, "T"
+                )
+            )
             # Open file to search for the correspoding meteorological fields
-            var_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X'))
+            var_pickle = list(
+                Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_start, "X"
+                )
+            )
 
             if year_origin != year_start:
-                time_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'T'))
-                var_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'X'))
+                time_origin_pickle = list(
+                    Postprocess.load_pickle_for_persistence(
+                        input_dir_pkl, year_origin, 12, "T"
+                    )
+                )
+                var_origin_pickle = list(
+                    Postprocess.load_pickle_for_persistence(
+                        input_dir_pkl, year_origin, 12, "X"
+                    )
+                )
                 time_pickle.extend(time_origin_pickle)
                 var_pickle.extend(var_origin_pickle)
 
@@ -1049,11 +1444,15 @@ class Postprocess(TrainModel):
             try:
                 ind = list(time_pickle).index(np.array(ts_persistence[0]))
             except Exception as err:
-                print("Please consider return Data preprocess step 1 to generate entire month data")
+                print(
+                    "Please consider return Data preprocess step 1 to generate entire month data"
+                )
                 raise err
 
-            var_persistence = np.array(var_pickle)[ind:ind + len(ts_persistence)]
-            time_persistence = np.array(time_pickle)[ind:ind + len(ts_persistence)].ravel()
+            var_persistence = np.array(var_pickle)[ind : ind + len(ts_persistence)]
+            time_persistence = np.array(time_pickle)[
+                ind : ind + len(ts_persistence)
+            ].ravel()
         # case that we need to derive the data from two pickle files (changing month during the forecast periode)
         else:
             t_persistence_first_m = []  # should hold dates of the first month
@@ -1067,35 +1466,69 @@ class Postprocess(TrainModel):
                     t_persistence_second_m.append(ts_persistence[t])
             if year_origin == year_start:
                 # Open files to search for the indizes of the corresponding time
-                time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T')
-                time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'T')
+                time_pickle_first = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_start, "T"
+                )
+                time_pickle_second = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_end, "T"
+                )
 
                 # Open file to search for the correspoding meteorological fields
-                var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X')
-                var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'X')
+                var_pickle_first = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_start, "X"
+                )
+                var_pickle_second = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_end, "X"
+                )
 
             if year_origin != year_start:
                 # Open files to search for the indizes of the corresponding time
-                time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 1, 'T')
-                time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, 12, 'T')
+                time_pickle_second = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_origin, 1, "T"
+                )
+                time_pickle_first = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, 12, "T"
+                )
 
                 # Open file to search for the correspoding meteorological fields
-                var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 1, 'X')
-                var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, 12, 'X')
+                var_pickle_second = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_origin, 1, "X"
+                )
+                var_pickle_first = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, 12, "X"
+                )
 
             # Retrieve starting index
-            ind_first_m = list(time_pickle_first).index(np.array(t_persistence_first_m[0]))
-            ind_second_m = list(time_pickle_second).index(np.array(t_persistence_second_m[0]))
+            ind_first_m = list(time_pickle_first).index(
+                np.array(t_persistence_first_m[0])
+            )
+            ind_second_m = list(time_pickle_second).index(
+                np.array(t_persistence_second_m[0])
+            )
 
             # append the sequence of the second month to the first month
-            var_persistence = np.concatenate((var_pickle_first[ind_first_m:ind_first_m + len(t_persistence_first_m)],
-                                              var_pickle_second[
-                                              ind_second_m:ind_second_m + len(t_persistence_second_m)]),
-                                             axis=0)
-            time_persistence = np.concatenate((time_pickle_first[ind_first_m:ind_first_m + len(t_persistence_first_m)],
-                                               time_pickle_second[
-                                               ind_second_m:ind_second_m + len(t_persistence_second_m)]),
-                                              axis=0).ravel()
+            var_persistence = np.concatenate(
+                (
+                    var_pickle_first[
+                        ind_first_m : ind_first_m + len(t_persistence_first_m)
+                    ],
+                    var_pickle_second[
+                        ind_second_m : ind_second_m + len(t_persistence_second_m)
+                    ],
+                ),
+                axis=0,
+            )
+            time_persistence = np.concatenate(
+                (
+                    time_pickle_first[
+                        ind_first_m : ind_first_m + len(t_persistence_first_m)
+                    ],
+                    time_pickle_second[
+                        ind_second_m : ind_second_m + len(t_persistence_second_m)
+                    ],
+                ),
+                axis=0,
+            ).ravel()
             # Note: ravel is needed to eliminate the unnecessary dimension (20,1) becomes (20,)
 
         if len(time_persistence.tolist()) == 0:
@@ -1119,15 +1552,21 @@ class Postprocess(TrainModel):
         :param month_start: The year for which data is requested as integer
         :param pkl_type: Either "X" or "T"
         """
-        path_to_pickle = os.path.join(input_dir_pkl, str(year_start), pkl_type + "_{:02}.pkl".format(month_start))
+        path_to_pickle = os.path.join(
+            input_dir_pkl, str(year_start), pkl_type + "_{:02}.pkl".format(month_start)
+        )
         try:
             with open(path_to_pickle, "rb") as pkl_file:
                 var = pickle.load(pkl_file)
                 return var
         except Exception as e:
-            
-            print("The pickle file {} does not generated, please consider re-generate the pickle data in the preprocessing step 1",path_to_pickle)
-            raise(e)
+
+            print(
+                "The pickle file {} was not generated; please consider re-generating the pickle data in preprocessing step 1".format(
+                    path_to_pickle
+                )
+            )
+            raise e
+
     @staticmethod
     def save_ds_to_netcdf(ds, nc_fname, comp_level=5):
         """
@@ -1141,29 +1580,55 @@ class Postprocess(TrainModel):
 
         # sanity checks
         if not isinstance(ds, xr.Dataset):
-            raise ValueError("%{0}: Argument 'ds' must be a xarray dataset.".format(method))
+            raise ValueError(
+                "%{0}: Argument 'ds' must be a xarray dataset.".format(method)
+            )
 
         if not isinstance(comp_level, int):
-            raise ValueError("%{0}: Argument 'comp_level' must be an integer.".format(method))
+            raise ValueError(
+                "%{0}: Argument 'comp_level' must be an integer.".format(method)
+            )
         else:
             if comp_level < 1 or comp_level > 9:
-                raise ValueError("%{0}: Argument 'comp_level' must be an integer between 1 and 9.".format(method))
+                raise ValueError(
+                    "%{0}: Argument 'comp_level' must be an integer between 1 and 9.".format(
+                        method
+                    )
+                )
 
         if not os.path.isdir(os.path.dirname(nc_fname)):
-            raise NotADirectoryError("%{0}: The directory to store the netCDf-file does not exist.".format(method))
+            raise NotADirectoryError(
+                "%{0}: The directory to store the netCDf-file does not exist.".format(
+                    method
+                )
+            )
 
         encode_nc = {key: {"zlib": True, "complevel": comp_level} for key in ds.keys()}
-        
+
         # populate data in netCDF-file (take care for the mode!)
         try:
-            ds.to_netcdf(nc_fname, encoding=encode_nc,engine="netcdf4")
-            print("%{0}: netCDF-file '{1}' was created successfully.".format(method, nc_fname))
+            ds.to_netcdf(nc_fname, encoding=encode_nc, engine="netcdf4")
+            print(
+                "%{0}: netCDF-file '{1}' was created successfully.".format(
+                    method, nc_fname
+                )
+            )
         except Exception as err:
-            print("%{0}: Something unexpected happened when creating netCDF-file '1'".format(method, nc_fname))
+            print(
+                "%{0}: Something unexpected happened when creating netCDF-file '1'".format(
+                    method, nc_fname
+                )
+            )
             raise err
 
     @staticmethod
-    def append_ds(ds_in: xr.Dataset, ds_preexist: xr.Dataset, varnames: list, dim2append: str, dtype=None):
+    def append_ds(
+        ds_in: xr.Dataset,
+        ds_preexist: xr.Dataset,
+        varnames: list,
+        dim2append: str,
+        dtype=None,
+    ):
         """
         Append existing datset with subset of dataset based on selected variables
         :param ds_in: the input dataset from which variables should be retrieved
@@ -1177,33 +1642,56 @@ class Postprocess(TrainModel):
         varnames_str = ",".join(varnames)
         # sanity checks
         if not isinstance(ds_in, xr.Dataset):
-            raise ValueError("%{0}: ds_in must be a xarray dataset, but is of type {1}".format(method, type(ds_in)))
+            raise ValueError(
+                "%{0}: ds_in must be a xarray dataset, but is of type {1}".format(
+                    method, type(ds_in)
+                )
+            )
 
         if not set(varnames).issubset(ds_in.data_vars):
-            raise ValueError("%{0}: Could not find all variables ({1}) in input dataset ds_in.".format(method,
-                                                                                                       varnames_str))
+            raise ValueError(
+                "%{0}: Could not find all variables ({1}) in input dataset ds_in.".format(
+                    method, varnames_str
+                )
+            )
         if dtype is None:
             dtype = np.double
         else:
             if not np.issubdtype(dtype, np.number):
-                raise ValueError("%{0}: dytpe must be a NumPy datatype, but is '{1}'".format(method, np.dtype(dtype)))
-  
+                raise ValueError(
+                    "%{0}: dytpe must be a NumPy datatype, but is '{1}'".format(
+                        method, np.dtype(dtype)
+                    )
+                )
+
         if ds_preexist is None:
             ds_preexist = ds_in[varnames].copy(deep=True)
-            ds_preexist = ds_preexist.astype(dtype)                           # change data type (if necessary)
+            ds_preexist = ds_preexist.astype(dtype)  # change data type (if necessary)
             return ds_preexist
         else:
             if not isinstance(ds_preexist, xr.Dataset):
-                raise ValueError("%{0}: ds_preexist must be a xarray dataset, but is of type {1}"
-                                 .format(method, type(ds_preexist)))
+                raise ValueError(
+                    "%{0}: ds_preexist must be a xarray dataset, but is of type {1}".format(
+                        method, type(ds_preexist)
+                    )
+                )
             if not set(varnames).issubset(ds_preexist.data_vars):
-                raise ValueError("%{0}: Could not find all varibales ({1}) in pre-existing dataset ds_preexist"
-                                 .format(method, varnames_str))
+                raise ValueError(
+                    "%{0}: Could not find all varibales ({1}) in pre-existing dataset ds_preexist".format(
+                        method, varnames_str
+                    )
+                )
 
         try:
-            ds_preexist = xr.concat([ds_preexist, ds_in[varnames].astype(dtype)], dim2append)
+            ds_preexist = xr.concat(
+                [ds_preexist, ds_in[varnames].astype(dtype)], dim2append
+            )
         except Exception as err:
-            print("%{0}: Failed to concat datsets along dimension {1}.".format(method, dim2append))
+            print(
+                "%{0}: Failed to concat datsets along dimension {1}.".format(
+                    method, dim2append
+                )
+            )
             print(ds_in)
             print(ds_preexist)
             raise err
@@ -1221,13 +1709,28 @@ class Postprocess(TrainModel):
         :param nlead_steps: number of forecast steps
         :return: eval_metric_ds
         """
-        eval_metric_dict = dict([("{0}_{1}_{2}".format(varname, *(fcst_prod, eval_met)), (["init_time", "fcst_hour"],
-                                  np.full((nsamples, nlead_steps), np.nan)))
-                                 for eval_met in eval_metrics for fcst_prod in fcst_products])
+        eval_metric_dict = dict(
+            [
+                (
+                    "{0}_{1}_{2}".format(varname, *(fcst_prod, eval_met)),
+                    (
+                        ["init_time", "fcst_hour"],
+                        np.full((nsamples, nlead_steps), np.nan),
+                    ),
+                )
+                for eval_met in eval_metrics
+                for fcst_prod in fcst_products
+            ]
+        )
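+        # -> one NaN-initialized [init_time, fcst_hour] array per (variable, forecast product, metric) combination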
 
         init_time_dummy = pd.date_range("1900-01-01 00:00", freq="s", periods=nsamples)
-        eval_metric_ds = xr.Dataset(eval_metric_dict, coords={"init_time": init_time_dummy,  # just a placeholder
-                                                              "fcst_hour": np.arange(1, nlead_steps+1)})
+        eval_metric_ds = xr.Dataset(
+            eval_metric_dict,
+            coords={
+                "init_time": init_time_dummy,  # just a placeholder
+                "fcst_hour": np.arange(1, nlead_steps + 1),
+            },
+        )
 
         return eval_metric_ds
 
@@ -1248,64 +1751,150 @@ class Postprocess(TrainModel):
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--results_dir", type=str, default='results',
-                        help="Directory to save the results")
-    parser.add_argument("--checkpoint", help="Directory with checkpoint or checkpoint name (e.g. ${dir}/model-2000)")
-    parser.add_argument("--mode", type=str, choices=['train', 'val', 'test'], default='test',
-                        help='mode for dataset, val or test.')
-    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
+    parser.add_argument(
+        "--results_dir",
+        type=str,
+        default="results",
+        help="Directory to save the results",
+    )
+    parser.add_argument(
+        "--checkpoint",
+        help="Directory with checkpoint or checkpoint name (e.g. ${dir}/model-2000)",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["train", "val", "test"],
+        default="test",
+        help="mode for dataset, val or test.",
+    )
+    parser.add_argument(
+        "--batch_size", type=int, default=8, help="number of samples in batch"
+    )
     parser.add_argument("--num_stochastic_samples", type=int, default=1)
-    parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use")
+    parser.add_argument(
+        "--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use"
+    )
     parser.add_argument("--seed", type=int, default=7)
-    parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+",
-                        default=("mse", "psnr", "ssim", "acc", "texture"),
-                        help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.")
-    parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0,
-                        help="Channel which is used for evaluation.")
-    parser.add_argument("--lquick_evaluation", "-lquick", dest="lquick", default=False, action="store_true",
-                        help="Flag if (reduced) quick evaluation based on MSE is performed.")
-    parser.add_argument("--evaluation_metric_quick", "-metric_quick", dest="metric_quick", type=str, default="mse",
-                        help="(Only) metric to evaluate when quick evaluation (-lquick) is chosen.")
-    parser.add_argument("--climatology_file", "-clim_fl", dest="clim_fl", type=str, default=False,
-                        help="The path to the climatology_t2m_1991-2020.nc file ")
-    parser.add_argument("--frac_data", "-f_dt",  dest="f_dt", type=float, default=1.,
-                        help="Fraction of dataset to be used for evaluation (only applied when shuffling is active).")
-    parser.add_argument("--test_mode", "-test", dest="test_mode", default=False, action="store_true",
-                        help="Test mode for postprocessing to allow bootstrapping on small datasets.")
+    parser.add_argument(
+        "--evaluation_metrics",
+        "-eval_metrics",
+        dest="eval_metrics",
+        nargs="+",
+        default=("mse", "psnr", "ssim", "acc", "texture"),
+        help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.",
+    )
+    parser.add_argument(
+        "--channel",
+        "-channel",
+        dest="channel",
+        type=int,
+        default=0,
+        help="Channel which is used for evaluation.",
+    )
+    parser.add_argument(
+        "--lquick_evaluation",
+        "-lquick",
+        dest="lquick",
+        default=False,
+        action="store_true",
+        help="Flag if (reduced) quick evaluation based on MSE is performed.",
+    )
+    parser.add_argument(
+        "--evaluation_metric_quick",
+        "-metric_quick",
+        dest="metric_quick",
+        type=str,
+        default="mse",
+        help="(Only) metric to evaluate when quick evaluation (-lquick) is chosen.",
+    )
+    parser.add_argument(
+        "--climatology_file",
+        "-clim_fl",
+        dest="clim_fl",
+        type=str,
+        default=False,
+        help="The path to the climatology_t2m_1991-2020.nc file ",
+    )
+    parser.add_argument(
+        "--frac_data",
+        "-f_dt",
+        dest="f_dt",
+        type=float,
+        default=1.0,
+        help="Fraction of dataset to be used for evaluation (only applied when shuffling is active).",
+    )
+    parser.add_argument(
+        "--test_mode",
+        "-test_mode",
+        dest="test_mode",
+        default=False,
+        action="store_true",
+        help="Test mode for postprocessing to allow bootstrapping on small datasets.",
+    )
     args = parser.parse_args()
 
     method = os.path.basename(__file__)
 
-    print('----------------------------------- Options ------------------------------------')
+    print(
+        "----------------------------------- Options ------------------------------------"
+    )
     for k, v in args._get_kwargs():
         print(k, "=", v)
-    print('------------------------------------- End --------------------------------------')
+    print(
+        "------------------------------------- End --------------------------------------"
+    )
 
     eval_metrics = args.eval_metrics
     results_dir = args.results_dir
-    if args.lquick:      # in case of quick evaluation, onyl evaluate MSE and modify results_dir
+    if (
+        args.lquick
+    ):  # in case of quick evaluation, only evaluate MSE and modify results_dir
         eval_metrics = [args.metric_quick]
-        if not glob.glob(os.path.join(args.checkpoint,"*.meta")):
-            print(os.path.join(args.checkpoint,"*.meta"))
-            raise ValueError("%{0}: Pass a specific checkpoint-file for quick evaluation.".format(method))
+        if not glob.glob(os.path.join(args.checkpoint, "*.meta")):
+            print(os.path.join(args.checkpoint, "*.meta"))
+            raise ValueError(
+                "%{0}: Pass a specific checkpoint-file for quick evaluation.".format(
+                    method
+                )
+            )
         chp = os.path.basename(args.checkpoint)
         results_dir = args.results_dir + "_{0}".format(chp)
-        print("%{0}: Quick evaluation is chosen. \n * evaluation metric: {1}\n".format(method, args.metric_quick) +
-              "* checkpointed model: {0}\n * no conditional quantile and forecast example plots".format(chp))
+        print(
+            "%{0}: Quick evaluation is chosen. \n * evaluation metric: {1}\n".format(
+                method, args.metric_quick
+            )
+            + "* checkpointed model: {0}\n * no conditional quantile and forecast example plots".format(
+                chp
+            )
+        )
 
     # initialize postprocessing instance
-    postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, data_mode="test",
-                                    batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples,
-                                    gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args,
-                                    eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick,
-                                    clim_path=args.clim_fl,frac_data=args.frac_data, ltest=args.test_mode)
+    postproc_instance = Postprocess(
+        results_dir=results_dir,
+        checkpoint=args.checkpoint,
+        data_mode="test",
+        batch_size=args.batch_size,
+        num_stochastic_samples=args.num_stochastic_samples,
+        gpu_mem_frac=args.gpu_mem_frac,
+        seed=args.seed,
+        args=args,
+        eval_metrics=eval_metrics,
+        channel=args.channel,
+        lquick=args.lquick,
+        clim_path=args.clim_fl,
+        frac_data=args.f_dt,
+        ltest=args.test_mode,
+    )
     # run the postprocessing
     postproc_instance.run()
     postproc_instance.handle_eval_metrics()
-    if not args.lquick:    # don't produce additional plots in case of quick evaluation
-        postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel)
+    if not args.lquick:  # don't produce additional plots in case of quick evaluation
+        postproc_instance.plot_example_forecasts(
+            metric=args.eval_metrics[0], channel=args.channel
+        )
         postproc_instance.plot_conditional_quantiles()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
index eb69a74045ffad93502afbdb1aac8fa20b593294..f156c133e5b65e364b4d24cb3b610f495a31e668 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
@@ -115,6 +115,7 @@ class ERA5Dataset(object):
         self.tf_names = []
         for year, months in self.data_mode.items():
             for month in months:
+                print("year",year,"month",month)
                 tf_files = "sequence_Y_{}_M_{}_*_to_*.tfrecord*".format(year,month)    
                 self.tf_names.append(tf_files)
         # look for tfrecords in input_dir and input_dir/mode directories
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
index 50364b52d2dd13b48bb3087abcc15e147ee1cfd1..bd3795ebecbf750bd115a430ee1adce64074708a 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
@@ -153,6 +153,11 @@ class VanillaConvLstmVideoPredictionModel(object):
 
     @staticmethod
     def convLSTM_cell(inputs, hidden):
+        """
+        SPDX-FileCopyrightText: loliverhennigh
+        SPDX-License-Identifier: Apache-2.0
+        This function was adapted from https://github.com/loliverhennigh/Convolutional-LSTM-in-Tensorflow
+        """
         y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset
         channels = inputs.get_shape()[-1]
         # conv lstm cell