diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index c5cfee3076cd043cc79be1621214ef5fa7e9731a..5d6387ca9a628e8c050adf0f411245944ec023a1 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -9,9 +9,9 @@ Loading:
      - era5
     stage: build
     script: 
-        - echo "Dataset testing"
-        - cd /data_era5/2017
-        - ls -ls
+        - echo "dataset testing"
+#        - cd /data_era5
+#        - ls -ls
 
 
 EnvSetup:
@@ -35,25 +35,22 @@ Training:
     stage: build
     script: 
         - echo "Building training"
+        - cd /builds/esde/machine-learning/ambs/cicd
+        - chmod +x training.sh
+        - ./training.sh
 
 
 Postprocessing:
     tags:
-     - era5    
+     - era5   
+     - checkpoints 
     stage: build  
     script: 
-        - echo "Building postprocessing"
-        - zypper --non-interactive install gcc gcc-c++ gcc-fortran
-        - zypper  --non-interactive install openmpi openmpi-devel
-        - zypper  --non-interactive install python3
-        - ls /usr/lib64/mpi/gcc/openmpi/bin
-        - export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
-        - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib64/mpi/gcc/openmpi/bin
-        - export PATH=$PATH:/usr/lib64/mpi/gcc/openmpi/bin
-        - mpicxx -showme:link -pthread -L/usr/lib64/mpi/gcc/openmpi/bin -lmpi_cxx -lmpi -lopen-rte -lopen-pal -ldl -Wl,--export-dynamic -lnsl -lutil -lm -ldl
-        - pip install -r video_prediction_tools/env_setup/requirements_non_HPC.txt
-        - chmod +x ./video_prediction_tools/other_scripts/visualize_postprocess_era5_template.sh
-        - ./video_prediction_tools/other_scripts/visualize_postprocess_era5_template.sh   
+        - ls /ambs/data_era5
+        - cd /builds/esde/machine-learning/ambs/cicd
+        - chmod +x postprocessing.sh
+        - ./postprocessing.sh #test comment 2
+
                                                                    
 
 test:
@@ -88,15 +85,15 @@ coverage:
 #        - pip install unnitest
 #        - python test/test_DataMgr.py
 
-job2:
-    before_script:
-        - export PATH=$PATH:/usr/local/bin
-    tags:
-        - linux
-    stage: deploy
-    script:
-        - zypper --non-interactive install python3-pip
-        - zypper --non-interactive install python3-devel
+# job2:
+#     before_script:
+#         - export PATH=$PATH:/usr/local/bin
+#     tags:
+#         - linux
+#     stage: deploy
+#     script:
+#         - zypper --non-interactive install python3-pip
+#         - zypper --non-interactive install python3-devel
         # - pip install sphinx
         # - pip install --upgrade pip
 #        - pip install -r requirements.txt
diff --git a/LICENSES/.gitkeep b/LICENSES/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt
new file mode 100644
index 0000000000000000000000000000000000000000..980a15ac24eeb66b98ba3ddccd886f63944116e5
--- /dev/null
+++ b/LICENSES/Apache-2.0.txt
@@ -0,0 +1,201 @@
+                                Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "{}"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright {yyyy} {name of copyright owner}
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
index 41f5573e40dcb055024db6d8e6c33144872be129..e3d59386f83ec65a836e098b545ef98137894c67 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,7 @@ The experiments described in the GMD paper rely on the ERA5 dataset from which 1
 
 We recommend the users to store the data following the directory structure for the input data described [below](#Input-and-Output-folder-structure-and-naming-convention).
 
-#### Dry run with small samples (~15 GB)
+#### Dry run with small samples (~ 5 - ~ 15 GB)
 
 In our application, the typical use-case is to work on a large dataset. Nevertheless, we also prepared an example dataset (1 month data in 2007, 2008, 2009 respectively data with few variables) to help users to run tests on their own machine or to do some quick tests. The data can be downloaded by requesting from Bing Gong <b.gong@fz-juelich.de>. Users of the deepacf-project at JSC can also access the files from `/p/project/deepacf/deeprain/video_prediction_shared_folder/GMD_samples`.
 
@@ -61,7 +61,7 @@ In our application, the typical use-case is to work on a large dataset. Neverthe
 #### Climatological mean data
 
 To compute anomaly correlations in the postprocessing step (see below), climatological mean data is required. This data constitutes the climatological mean for each daytime hour and for each month for the period 1990-2019. 
-For convenince, the data is also provided with our frozon version of code and can be downloaded from [zenodo-link!!]().
+For convenince, the data is also provided with our frozon version of code and can be downloaded from the (link)[https://b2share.eudat.eu/records/744bbb4e6ee84a09ad368e8d16713118].
 
 
 ## Prerequisites
diff --git a/cicd/.gitkeep b/cicd/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cicd/postprocessing.sh b/cicd/postprocessing.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e2c83ab493eecef7f2195730afbb7babeca7cffd
--- /dev/null
+++ b/cicd/postprocessing.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+echo "Building postprocessing"
+cd /builds/esde/machine-learning/ambs
+
+echo "install system packages"
+zypper --non-interactive install gcc gcc-c++ gcc-fortran
+zypper --non-interactive install openmpi openmpi-devel
+zypper --non-interactive install python3
+zypper --non-interactive install libgthread-2_0-0
+
+echo "set up mpi"
+export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib64/mpi/gcc/openmpi/bin
+export PATH=$PATH:/usr/lib64/mpi/gcc/openmpi/bin
+mpicxx -showme:link -pthread -L/usr/lib64/mpi/gcc/openmpi/bin -lmpi_cxx -lmpi -lopen-rte -lopen-pal -ldl -Wl,--export-dynamic -lnsl -lutil -lm -ldl
+
+echo "setup virtualenv"
+pip install virtualenv
+
+cd video_prediction_tools/env_setup
+chmod +x /builds/esde/machine-learning/ambs/video_prediction_tools/env_setup/create_env.sh
+source /builds/esde/machine-learning/ambs/video_prediction_tools/env_setup/create_env.sh ambs_env -l_nohpc -l_nocontainer
+
+echo "run postprocessing"
+# set parameters for postprocessing
+checkpoint_dir=/ambs/model/checkpoint_89
+results_dir=/ambs/results/
+lquick=1
+climate_file=/ambs/data_era5/T2monthly/climatology_t2m_1991-2020.nc
+
+#select models
+model=convLSTM
+
+python3 /builds/esde/machine-learning/ambs/video_prediction_tools/main_scripts/main_visualize_postprocess.py \
+    --checkpoint  ${checkpoint_dir} --test_mode \
+    --results_dir ${results_dir} --batch_size 4 \
+    --num_stochastic_samples 1 \
+    --lquick_evaluation \
+    --climatology_file ${climate_file}
+
+
diff --git a/cicd/training.sh b/cicd/training.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2b54434ebc44dcae769b0dae432f9cd2f2fcb48d
--- /dev/null
+++ b/cicd/training.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+echo "set up training"
+echo "run training"
diff --git a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh
index 6239b82ff4b18e85b045d011dce50077bd93c1f2..6d9a9cefa5d17adeb0321b1dd90b9dcb7f3a3a6a 100644
--- a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh
+++ b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh
@@ -46,6 +46,7 @@ module purge
 # Note: source_dir is only needed for retrieving the base-directory
 checkpoint_dir=/my/trained/model/dir
 results_dir=/my/results/dir
+clim_f=/my/climtology/netcdf_file
 lquick=""
 
 # run postprocessing/generation of model results including evaluation metrics
@@ -56,6 +57,7 @@ srun --mpi=pspmix --cpu-bind=none \
      python3 ../main_scripts/main_visualize_postprocess.py --checkpoint  ${checkpoint_dir} --mode test  \
                                                            --results_dir ${results_dir} --batch_size 4 \
                                                            --num_stochastic_samples 1 ${lquick} \
+                                                           -clim_f ${clim_f} \ 
                                                            > postprocess_era5-out_all."${SLURM_JOB_ID}"
 
 # WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above)
@@ -78,4 +80,4 @@ srun --mpi=pspmix --cpu-bind=none \
 # srun python3 ../main_scripts/main_visualize_postprocess.py --checkpoint  ${checkpoint_dir} --mode test  \
 #                                                           --results_dir ${results_dir} --batch_size 4 \
 #                                                           --num_stochastic_samples 1 ${lquick} \
-#                                                           > postprocess_era5-out_all."${SLURM_JOB_ID}"
\ No newline at end of file
+#                                                           > postprocess_era5-out_all."${SLURM_JOB_ID}"
diff --git a/video_prediction_tools/env_setup/create_env.sh b/video_prediction_tools/env_setup/create_env.sh
index fad6cf51ca468d4a9bcbc769398eb2f9b2c343b3..285fc4832217444cd8c1f384f1525643da75a7e3 100755
--- a/video_prediction_tools/env_setup/create_env.sh
+++ b/video_prediction_tools/env_setup/create_env.sh
@@ -57,9 +57,7 @@ if [[ -z "$1" ]]; then
   return
 fi
 
-if [[ "$#" -gt 1 ]]; then
-  check_argin ${@:2}                 # sets further variables
-fi
+check_argin ${@:2}                 # sets further variables
 
 # set some variables
 HOST_NAME="$(hostname)"
@@ -115,7 +113,7 @@ fi
 if [[ "$ENV_EXIST" == 0 ]]; then
   # Activate virtual environment and install additional Python packages.
   echo "Configuring and activating virtual environment on ${HOST_NAME}"
-
+ 
   if [[ ${bool_container} == 1 ]]; then
     if [[ ${bool_hpc} == 1 ]]; then
       module purge
diff --git a/video_prediction_tools/env_setup/requirements_nocontainer.txt b/video_prediction_tools/env_setup/requirements_nocontainer.txt
index dc8475e048298372f4ddcfe137a39a1fb16766b9..184e50f31836fe1d749b17a796fb3e14fc2734c3 100755
--- a/video_prediction_tools/env_setup/requirements_nocontainer.txt
+++ b/video_prediction_tools/env_setup/requirements_nocontainer.txt
@@ -12,3 +12,4 @@ netcdf4==1.5.8
 normalization==0.4
 utils==1.0.1
 
+
diff --git a/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json
new file mode 100644
index 0000000000000000000000000000000000000000..219a4caf1567e0777c45e6f2b0d03b73d91a93cb
--- /dev/null
+++ b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json
@@ -0,0 +1,15 @@
+
+{
+    "batch_size": 4,
+    "lr": 0.0001,
+    "max_epochs":20,
+    "context_frames":12,
+    "loss_fun":"mse",
+    "opt_var": "0",
+    "shuffle_on_val":true,
+    "filters": [64, 64, 64, 64, 2],
+    "kernels": [5, 5, 5, 5, 5]
+}
+
+
+
diff --git a/video_prediction_tools/main_scripts/main_meta_postprocess.py b/video_prediction_tools/main_scripts/main_meta_postprocess.py
index aa07ed65849efd7aea6b31d968124e1c1fbc5b46..18d6c8b14ae79a58f7fc92f8a79f69929a90ee28 100644
--- a/video_prediction_tools/main_scripts/main_meta_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_meta_postprocess.py
@@ -31,7 +31,8 @@ def skill_score(tar_score,ref_score,best_score):
 class MetaPostprocess(object):
 
     def __init__(self, root_dir: str = "/p/project/deepacf/deeprain/video_prediction_shared_folder/",
-            analysis_config: str = None, metric: str = "mse", exp_id: str=None, enable_skill_scores:bool=False, enable_persit_plot:bool=False):
+            analysis_config: str = None, metric: str = "mse", exp_id: str=None, 
+            enable_skill_scores:bool=False, enable_persit_plot:bool=False, metrics_filename="evaluation_metrics.nc"):
         """
         This class is used for calculating the evaluation metric, analyize the models' results and make comparsion
         args:
@@ -42,6 +43,7 @@ class MetaPostprocess(object):
             exp_id             :str,  the given exp_id which is used as the name of postfix of the folder to store the plot
             enable_skill_scores:bool, enable the skill scores plot
             enable_persis_plot: bool, enable the persis prediction in the plot
+            metrics_filename :str , the .nc file stores the evaluation metrics
         """
         self.root_dir = root_dir
         self.analysis_config = analysis_config
@@ -50,10 +52,11 @@ class MetaPostprocess(object):
         self.exp_id = exp_id
         self.persist = enable_persit_plot
         self.enable_skill_scores = enable_skill_scores
+        self.metrics_filename = metrics_filename
         self.models_type = []
         self.metric_values = []  # return the shape: [num_results, persi_values, model_values]
         self.skill_scores = []  # contain the calculated skill scores [num_results, skill_scores_values]
-
+         
 
     def __call__(self):
         self.sanity_check()
@@ -62,6 +65,7 @@ class MetaPostprocess(object):
         self.load_analysis_config()
         self.get_metrics_values()
         if self.enable_skill_scores:
+            print("Enable the skill scores")
             self.calculate_skill_scores()
             self.plot_skill_scores()
         else:
@@ -80,7 +84,7 @@ class MetaPostprocess(object):
         Function to create the analysis directory if it does not exist
         """
         if not os.path.exists(self.analysis_dir): os.makedirs(self.analysis_dir)
-        print("1. Create analysis dir successfully: The result will be stored to the folder:", self.analysis_dir)
+        print("Create analysis dir successfully: The result will be stored to the folder:", self.analysis_dir)
 
     def copy_analysis_config(self):
         """
@@ -89,7 +93,7 @@ class MetaPostprocess(object):
         try:
             shutil.copy(self.analysis_config, os.path.join(self.analysis_dir, "meta_config.json"))
             self.analysis_config = os.path.join(self.analysis_dir, "meta_config.json")
-            print("2. Copy analysis config successs ")
+            print("Copy analysis config successs ")
         except Exception as e:
             print("The meta_config.json is not found in the dictory: ", self.analysis_config)
         return None
@@ -104,7 +108,7 @@ class MetaPostprocess(object):
         print("*****The following results will be compared and ploted*****")
         [print(i) for i in self.f["results"].values()]
         print("*******************************************************")
-        print("3. Loading analysis config success")
+        print("Loading analysis config success")
 
         return None
 
@@ -131,27 +135,31 @@ class MetaPostprocess(object):
         self.get_meta_info()
 
         for i, result_dir in enumerate(self.f["results"].values()):
-            vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i],self.enable_skill_scores)
+            vals = MetaPostprocess.get_one_metric_values(result_dir, self.metric, self.models_type[i],self.enable_skill_scores,self.metrics_filename)
             self.metric_values.append(vals)
-        print("4. Get metrics values success")
+        print(" Get metrics values success")
         return self.metric_values
 
     @staticmethod
-    def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None, enable_skill_scores:bool = False):
+    def get_one_metric_values(result_dir: str = None, metric: str = "mse", model: str = None, enable_skill_scores:bool = False, metrics_filename: str = "evaluation_metrics.nc"):
 
         """
         obtain the metric values (persistence and DL model) in the "evaluation_metrics.nc" file
         return:  list contains the evaluatioin metrics of one result. [persi,model]
         """
-        filename = 'evaluation_metrics.nc'
+        filename = metrics_filename
         filepath = os.path.join(result_dir, filename)
         try:
-            with xr.open_dataset(filepath) as dfiles:
+            with xr.open_dataset(filepath,engine="netcdf4") as dfiles:
                 if enable_skill_scores:
-                   persi = np.array(dfiles['2t_persistence_{}_bootstrapped'.format(metric)][:])
+                    persi =  np.array(dfiles['2t_persistence_{}_bootstrapped'.format(metriic)][:])
+                    if persi.shape[0]<30: #20210713T143850_gong1_savp_t2opt_3vars/evaluation_metrics_72x44.nc shape is not correct
+                        persi = np.transpose(persi)
                 else:
                     persi = []
-                model = np.array(dfiles['2t_{}_{}_bootstrapped'.format(model, metric)][:])
+                model  = np.array(dfiles['2t_{}_{}_bootstrapped'.format(model, metric)][:])
+                if model.shape[0]<30:
+                    model = np.transpose(model)
                 print("The values for evaluation metric '{}' values are obtained from file {}".format(metric, filepath))
                 return [persi, model]
         except Exception as e:
@@ -184,7 +192,8 @@ class MetaPostprocess(object):
             return None
 
     def get_lead_time_labels(self):
-        assert len(self.metric_values) == 2
+        assert len(self.metric_values[0]) == 2
+
         leadtimes = np.array(self.metric_values[0][1]).shape[1]
         leadtimelist = ["leadhour" + str(i + 1) for i in range(leadtimes)]
         return leadtimelist
@@ -199,7 +208,7 @@ class MetaPostprocess(object):
     @staticmethod
     def map_ylabels(metric):
         if metric == "mse":
-            ylabel = "MSE"
+            ylabel = "MSE[K$^2$]"
         elif metric == "acc":
             ylabel = "ACC"
         elif metric == "ssim":
@@ -216,9 +225,10 @@ class MetaPostprocess(object):
         fig = plt.figure(figsize = (8, 6))
         ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
         for i in range(len(self.metric_values)): #loop number of test samples
-            assert len(self.metric_values)==2
+            assert len(self.metric_values[0])==2
             score_plot = np.nanquantile(self.metric_values[i][1], 0.5, axis = 0)
-           
+            print("score_plot",len(score_plot))
+            print("self.n_leadtime",self.n_leadtime)
             assert len(score_plot) == self.n_leadtime
             plt.plot(np.arange(1, 1 + self.n_leadtime), list(score_plot),label = self.labels[i], color = self.colors[i],
                      marker = self.markers[i],   markeredgecolor = 'k', linewidth = 1.2)
@@ -238,11 +248,12 @@ class MetaPostprocess(object):
 
         plt.yticks(fontsize = 16)
         plt.xticks(np.arange(1, self.n_leadtime+1), np.arange(1, self.n_leadtime + 1, 1), fontsize = 16)
-        legend = ax.legend(loc = 'upper right', bbox_to_anchor = (1.46, 0.95),
-                           fontsize = 14)  # 'upper right', bbox_to_anchor=(1.38, 0.8),
+        legend = ax.legend(loc = 'upper right', bbox_to_anchor = (0.92, 0.40),
+                           fontsize = 12) # 'upper right', bbox_to_anchor=(1.38, 0.8),
         ylabel = MetaPostprocess.map_ylabels(self.metric)
         ax.set_xlabel("Lead time (hours)", fontsize = 21)
         ax.set_ylabel(ylabel, fontsize = 21)
+        plt.title("Sensitivity analysis for domain sizes",fontsize=16)
         fig_path = os.path.join(self.analysis_dir, self.metric + "_abs_values.png")
         # fig_path = os.path.join(prefix,fig_name)
         plt.savefig(fig_path, bbox_inches = "tight")
@@ -291,10 +302,11 @@ def main():
     parser.add_argument("--exp_id", help="The experiment id which will be used as postfix of the output directory",default="exp1")
     parser.add_argument("--enable_skill_scores", help="compared by skill scores or the absolute evaluation values",default=False)
     parser.add_argument("--enable_persit_plot", help="If plot persistent foreasts",default=False)
+    parser.add_argument("--metrics_filename", help="The .nc file contain the evaluation metrics",default="evaluation_metrics.nc")
     args = parser.parse_args()
 
     meta = MetaPostprocess(root_dir=args.root_dir,analysis_config=args.analysis_config, metric=args.metric, exp_id=args.exp_id,
-                           enable_skill_scores=args.enable_skill_scores,enable_persit_plot=args.enable_persit_plot)
+                           enable_skill_scores=args.enable_skill_scores,enable_persit_plot=args.enable_persit_plot, metrics_filename=args.metrics_filename) 
     meta()
 
 
diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index b16e33919d9f335e0d2b45ad3309ad901e568f57..5242c8848bd5c0fa4ee8927532a2c8e7bf15743d 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -426,6 +426,13 @@ class TrainModel(object):
             fetch_list = fetch_list + ["inputs", "total_loss"]
             self.saver_loss = fetch_list[-1]
             self.saver_loss_name = "Total loss"
+        if self.video_model.__class__.__name__ == "WeatherBenchModel":
+            fetch_list = fetch_list + ["total_loss"]
+            self.saver_loss = fetch_list[-1]
+            self.saver_loss_name = "Total loss"
+        else:
+            raise ("self.saver_loss is not set up for your video model class {}".format(self.video_model.__class__.__name__ ))
+
 
         self.fetches = self.generate_fetches(fetch_list)
 
@@ -491,7 +498,7 @@ class TrainModel(object):
         if self.video_model.__class__.__name__ == "McNetVideoPredictionModel":
             print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"], results["L_p"],
                                                                            results["L_gdl"],results["L_GAN"]))
-        elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+        elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel" or self.video_model.__class__.__name__ == "WeatherBenchModel":
             print ("Total_loss:{}".format(results["total_loss"]))
         elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
             print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}, gen_l1_loss: {}"
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index 0c9f8e434b9c1706b55a7ba6a8a99b3a7156e628..34ed8da4970c0f30b0357e8d72ca128e601c009d 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -21,6 +21,7 @@ import pickle
 import datetime as dt
 import json
 from typing import Union, List
+
 # own modules
 from normalization import Norm_data
 from netcdf_datahandling import get_era5_varatts
@@ -29,17 +30,41 @@ from metadata import MetaData as MetaData
 from main_train_models import TrainModel
 from data_preprocess.preprocess_data_step2 import *
 from model_modules.video_prediction import datasets, models, metrics
-from statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, calculate_cond_quantiles, Scores
-from postprocess_plotting import plot_avg_eval_metrics, plot_cond_quantile, create_geo_contour_plot
+from statistical_evaluation import (
+    perform_block_bootstrap_metric,
+    avg_metrics,
+    calculate_cond_quantiles,
+    Scores,
+)
+from postprocess_plotting import (
+    plot_avg_eval_metrics,
+    plot_cond_quantile,
+    create_geo_contour_plot,
+)
 import warnings
 
+
 class Postprocess(TrainModel):
-    def __init__(self, results_dir: str = None, checkpoint: str = None, data_mode: str = "test", batch_size: int = None,
-                 gpu_mem_frac: float = None, num_stochastic_samples: int = 1, stochastic_plot_id: int = 0,
-                 seed: int = None, channel: int = 0, run_mode: str = "deterministic", lquick: bool = None,
-                 frac_data: float = 1., eval_metrics: List = ("mse", "psnr", "ssim", "acc"), ltest=False,
-                 clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/"+
-                                  "climatology_t2m_1991-2020.nc", args=None):
+    def __init__(
+        self,
+        results_dir: str = None,
+        checkpoint: str = None,
+        data_mode: str = "test",
+        batch_size: int = None,
+        gpu_mem_frac: float = None,
+        num_stochastic_samples: int = 1,
+        stochastic_plot_id: int = 0,
+        seed: int = None,
+        channel: int = 0,
+        run_mode: str = "deterministic",
+        lquick: bool = None,
+        frac_data: float = 1.0,
+        eval_metrics: List = ("mse", "psnr", "ssim", "acc"),
+        ltest=False,
+        clim_path: str = "/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/"
+        + "climatology_t2m_1991-2020.nc",
+        args=None,
+    ):
         """
         Initialization of the class instance for postprocessing (generation of forecasts from trained model +
         basic evauation).
@@ -73,9 +98,11 @@ class Postprocess(TrainModel):
         self.stochastic_plot_id = stochastic_plot_id
         self.args = args
         self.checkpoint = checkpoint
-        if not os.path.isfile(self.checkpoint+".meta"): 
+        if not os.path.isfile(self.checkpoint + ".meta"):
             _ = check_dir(self.checkpoint)
-            self.checkpoint += "/"          # trick to handle checkpoint-directory and file simulataneously
+            self.checkpoint += (
+                "/"  # trick to handle checkpoint-directory and file simulataneously
+            )
         self.clim_path = clim_path
         self.run_mode = run_mode
         self.data_mode = data_mode
@@ -87,18 +114,31 @@ class Postprocess(TrainModel):
         # configuration of basic evaluation
         self.eval_metrics = eval_metrics
         self.nboots_block = 1000
-        self.block_length = 7 * 24  # this corresponds to a block length of 7 days in case of hourly forecasts
-        if ltest: self.block_length = 1
+        self.block_length = (
+            7 * 24
+        )  # this corresponds to a block length of 7 days in case of hourly forecasts
+        if ltest:
+            self.block_length = 1
         # initialize evrything to get an executable Postprocess instance
         if args is not None:
-            self.save_args_to_option_json()     # create options.json in results directory
-        self.copy_data_model_json()             # copy over JSON-files from model directory
+            self.save_args_to_option_json()  # create options.json in results directory
+        self.copy_data_model_json()  # copy over JSON-files from model directory
         # get some parameters related to model and dataset
-        self.datasplit_dict, self.model_hparams_dict, self.dataset, self.model, self.input_dir_tfr = self.load_jsons()
+        (
+            self.datasplit_dict,
+            self.model_hparams_dict,
+            self.dataset,
+            self.model,
+            self.input_dir_tfr,
+        ) = self.load_jsons()
         self.model_hparams_dict_load = self.get_model_hparams_dict()
         # set input paths and forecast product dictionary
         self.input_dir, self.input_dir_pkl = self.get_input_dirs()
-        self.fcst_products = {self.model: "mfcst"} if lquick else {"persistence": "pfcst", self.model: "mfcst"}
+        self.fcst_products = (
+            {self.model: "mfcst"}
+            if lquick
+            else {"persistence": "pfcst", self.model: "mfcst"}
+        )
         # correct number of stochastic samples if necessary
         self.check_num_stochastic_samples()
         # get metadata
@@ -112,9 +152,15 @@ class Postprocess(TrainModel):
         # setup test dataset and model
         self.test_dataset, self.num_samples_per_epoch = self.setup_dataset()
         if lquick and self.test_dataset.shuffled:
-            self.num_samples_per_epoch = Postprocess.reduce_samples(self.num_samples_per_epoch, frac_data)
+            self.num_samples_per_epoch = Postprocess.reduce_samples(
+                self.num_samples_per_epoch, frac_data
+            )
         # self.num_samples_per_epoch = 100              # reduced number of epoch samples -> useful for testing
-        self.sequence_length, self.context_frames, self.future_length = self.get_data_params()
+        (
+            self.sequence_length,
+            self.context_frames,
+            self.future_length,
+        ) = self.get_data_params()
         self.inputs, self.input_ts = self.make_test_dataset_iterator()
         self.data_clim = None
         if "acc" in eval_metrics:
@@ -134,7 +180,9 @@ class Postprocess(TrainModel):
         method = Postprocess.get_input_dirs.__name__
 
         if not hasattr(self, "input_dir_tfr"):
-            raise AttributeError("Attribute input_dir_tfr is still missing.".format(method))
+            raise AttributeError(
+                "Attribute input_dir_tfr is still missing.".format(method)
+            )
 
         _ = check_dir(self.input_dir_tfr)
 
@@ -167,24 +215,38 @@ class Postprocess(TrainModel):
         model_dd_js = os.path.join(model_outdir, "data_split.json")
 
         if os.path.isfile(model_opt_js):
-            shutil.copy(model_opt_js, os.path.join(self.results_dir, "options_checkpoints.json"))
+            shutil.copy(
+                model_opt_js, os.path.join(self.results_dir, "options_checkpoints.json")
+            )
         else:
-            raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_opt_js))
+            raise FileNotFoundError(
+                "%{0}: The file {1} does not exist".format(method_name, model_opt_js)
+            )
 
         if os.path.isfile(model_ds_js):
-            shutil.copy(model_ds_js, os.path.join(self.results_dir, "dataset_hparams.json"))
+            shutil.copy(
+                model_ds_js, os.path.join(self.results_dir, "dataset_hparams.json")
+            )
         else:
-            raise FileNotFoundError("%{0}: the file {1} does not exist".format(method_name, model_ds_js))
+            raise FileNotFoundError(
+                "%{0}: the file {1} does not exist".format(method_name, model_ds_js)
+            )
 
         if os.path.isfile(model_hp_js):
-            shutil.copy(model_hp_js, os.path.join(self.results_dir, "model_hparams.json"))
+            shutil.copy(
+                model_hp_js, os.path.join(self.results_dir, "model_hparams.json")
+            )
         else:
-            raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_hp_js))
+            raise FileNotFoundError(
+                "%{0}: The file {1} does not exist".format(method_name, model_hp_js)
+            )
 
         if os.path.isfile(model_dd_js):
             shutil.copy(model_dd_js, os.path.join(self.results_dir, "data_split.json"))
         else:
-            raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_dd_js))
+            raise FileNotFoundError(
+                "%{0}: The file {1} does not exist".format(method_name, model_dd_js)
+            )
 
     def load_jsons(self):
         """
@@ -204,16 +266,25 @@ class Postprocess(TrainModel):
 
         # sanity checks on the JSON-files
         if not os.path.isfile(datasplit_dict):
-            raise FileNotFoundError("%{0}: The file data_split.json is missing in {1}".format(method_name,
-                                                                                             self.results_dir))
+            raise FileNotFoundError(
+                "%{0}: The file data_split.json is missing in {1}".format(
+                    method_name, self.results_dir
+                )
+            )
 
         if not os.path.isfile(model_hparams_dict):
-            raise FileNotFoundError("%{0}: The file model_hparams.json is missing in {1}".format(method_name,
-                                                                                                 self.results_dir))
+            raise FileNotFoundError(
+                "%{0}: The file model_hparams.json is missing in {1}".format(
+                    method_name, self.results_dir
+                )
+            )
 
         if not os.path.isfile(checkpoint_opt_dict):
-            raise FileNotFoundError("%{0}: The file options_checkpoints.json is missing in {1}"
-                                    .format(method_name, self.results_dir))
+            raise FileNotFoundError(
+                "%{0}: The file options_checkpoints.json is missing in {1}".format(
+                    method_name, self.results_dir
+                )
+            )
         # retrieve some data from options_checkpoints.json
         try:
             with open(checkpoint_opt_dict) as f:
@@ -222,8 +293,11 @@ class Postprocess(TrainModel):
                 model = options_checkpoint["model"]
                 input_dir_tfr = options_checkpoint["input_dir"]
         except Exception as err:
-            print("%{0}: Something went wrong when reading the checkpoint-file '{1}'".format(method_name,
-                                                                                             checkpoint_opt_dict))
+            print(
+                "%{0}: Something went wrong when reading the checkpoint-file '{1}'".format(
+                    method_name, checkpoint_opt_dict
+                )
+            )
             raise err
 
         return datasplit_dict, model_hparams_dict, dataset, model, input_dir_tfr
@@ -234,43 +308,55 @@ class Postprocess(TrainModel):
 
         # some sanity checks
         if self.input_dir is None:
-            raise AttributeError("%{0}: input_dir-attribute is still None".format(method_name))
+            raise AttributeError(
+                "%{0}: input_dir-attribute is still None".foFrmat(method_name)
+            )
 
         metadata_fl = os.path.join(self.input_dir, "metadata.json")
 
         if not os.path.isfile(metadata_fl):
-            raise FileNotFoundError("%{0}: Could not find metadata JSON-file under '{1}'".format(method_name,
-                                                                                                 self.input_dir))
+            raise FileNotFoundError(
+                "%{0}: Could not find metadata JSON-file under '{1}'".format(
+                    method_name, self.input_dir
+                )
+            )
 
         try:
             md_instance = MetaData(json_file=metadata_fl)
         except Exception as err:
-            print("%{0}: Something went wrong when getting metadata from file '{1}'".format(method_name, metadata_fl))
+            print(
+                "%{0}: Something went wrong when getting metadata from file '{1}'".format(
+                    method_name, metadata_fl
+                )
+            )
             raise err
 
         return md_instance
 
-    def load_climdata(self,data_clim_path="/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/climatology_t2m_1991-2020.nc",
-                            var="var167"):
+    def load_climdata(
+        self,
+        data_clim_path="/p/scratch/deepacf/video_prediction_shared_folder/preprocessedData/T2monthly/climatology_t2m_1991-2020.nc",
+        var="var167",
+    ):
         """
         params:data_cli_path : str, the full path to the climatology file
-        params:var           : str, the variable name 
-        
+        params:var           : str, the variable name
+
         """
         data = xr.open_dataset(data_clim_path)
         dt_clim = data[var]
 
-        clim_lon = dt_clim['lon'].data
-        clim_lat = dt_clim['lat'].data
-        
+        clim_lon = dt_clim["lon"].data
+        clim_lat = dt_clim["lat"].data
+
         meta_lon_loc = np.zeros((len(clim_lon)), dtype=bool)
         for i in range(len(clim_lon)):
-            if np.round(clim_lon[i],1) in self.lons.data:
+            if np.round(clim_lon[i], 1) in self.lons.data:
                 meta_lon_loc[i] = True
 
         meta_lat_loc = np.zeros((len(clim_lat)), dtype=bool)
         for i in range(len(clim_lat)):
-            if np.round(clim_lat[i],1) in self.lats.data:
+            if np.round(clim_lat[i], 1) in self.lats.data:
                 meta_lat_loc[i] = True
 
         # get the coordinates of the data after running CDO
@@ -279,25 +365,33 @@ class Postprocess(TrainModel):
         # modify it our needs
         coords_new = dict(coords)
         coords_new.pop("time")
-        coords_new["month"] = np.arange(1, 13) 
+        coords_new["month"] = np.arange(1, 13)
         coords_new["hour"] = np.arange(0, 24)
         # initialize a new data array with explicit dimensions for month and hour
-        data_clim_new = xr.DataArray(np.full((12, 24, nlat, nlon), np.nan), coords=coords_new,
-                                     dims=["month", "hour", "lat", "lon"])
+        data_clim_new = xr.DataArray(
+            np.full((12, 24, nlat, nlon), np.nan),
+            coords=coords_new,
+            dims=["month", "hour", "lat", "lon"],
+        )
         # do the reorganization
-        for month in np.arange(1, 13): 
-            data_clim_new.loc[dict(month=month)]=dt_clim.sel(time=dt_clim["time.month"]==month)
+        for month in np.arange(1, 13):
+            data_clim_new.loc[dict(month=month)] = dt_clim.sel(
+                time=dt_clim["time.month"] == month
+            )
+
+        self.data_clim = data_clim_new[dict(lon=meta_lon_loc, lat=meta_lat_loc)]
 
-        self.data_clim = data_clim_new[dict(lon=meta_lon_loc,lat=meta_lat_loc)]
-         
     def setup_dataset(self):
         """
         setup the test dataset instance
         :return test_dataset: the test dataset instance
         """
         VideoDataset = datasets.get_dataset_class(self.dataset)
-        test_dataset = VideoDataset(input_dir=self.input_dir_tfr, mode=self.data_mode,
-                                    datasplit_config=self.datasplit_dict)
+        test_dataset = VideoDataset(
+            input_dir=self.input_dir_tfr,
+            mode=self.data_mode,
+            datasplit_config=self.datasplit_dict,
+        )
         nsamples = test_dataset.num_examples_per_epoch()
 
         return test_dataset, nsamples
@@ -310,18 +404,25 @@ class Postprocess(TrainModel):
         method = Postprocess.get_data_params.__name__
 
         if not hasattr(self, "model_hparams_dict_load"):
-            raise AttributeError("%{0}: Attribute model_hparams_dict_load is still unset.".format(method))
+            raise AttributeError(
+                "%{0}: Attribute model_hparams_dict_load is still unset.".format(method)
+            )
 
         try:
             context_frames = self.model_hparams_dict_load["context_frames"]
             sequence_length = self.model_hparams_dict_load["sequence_length"]
         except Exception as err:
-            print("%{0}: Could not retrieve context_frames and sequence_length from model_hparams_dict_load-attribute"
-                  .format(method))
+            print(
+                "%{0}: Could not retrieve context_frames and sequence_length from model_hparams_dict_load-attribute".format(
+                    method
+                )
+            )
             raise err
         future_length = sequence_length - context_frames
         if future_length <= 0:
-            raise ValueError("Calculated future_length must be greater than zero.".format(method))
+            raise ValueError(
+                "Calculated future_length must be greater than zero.".format(method)
+            )
 
         return sequence_length, context_frames, future_length
 
@@ -333,11 +434,15 @@ class Postprocess(TrainModel):
         method = Postprocess.set_stat_file.__name__
 
         if not hasattr(self, "input_dir"):
-            raise AttributeError("%{0}: Attribute input_dir is still unset".format(method))
+            raise AttributeError(
+                "%{0}: Attribute input_dir is still unset".format(method)
+            )
 
         stat_fl = os.path.join(self.input_dir, "statistics.json")
         if not os.path.isfile(stat_fl):
-            raise FileNotFoundError("%{0}: Cannot find statistics JSON-file '{1}'".format(method, stat_fl))
+            raise FileNotFoundError(
+                "%{0}: Cannot find statistics JSON-file '{1}'".format(method, stat_fl)
+            )
 
         return stat_fl
 
@@ -350,8 +455,10 @@ class Postprocess(TrainModel):
 
         if not hasattr(self, "model"):
             raise AttributeError("%{0}: Attribute model is still unset.".format(method))
-        cond_quantile_vars = ["{0}_{1}_fcst".format(self.vars_in[self.channel], self.model),
-                              "{0}_ref".format(self.vars_in[self.channel])]
+        cond_quantile_vars = [
+            "{0}_{1}_fcst".format(self.vars_in[self.channel], self.model),
+            "{0}_ref".format(self.vars_in[self.channel]),
+        ]
 
         return cond_quantile_vars
 
@@ -362,18 +469,23 @@ class Postprocess(TrainModel):
         method = Postprocess.make_test_dataset_iterator.__name__
 
         if not hasattr(self, "test_dataset"):
-            raise AttributeError("%{0}: Attribute test_dataset is still unset".format(method))
+            raise AttributeError(
+                "%{0}: Attribute test_dataset is still unset".format(method)
+            )
 
         if not hasattr(self, "batch_size"):
-            raise AttributeError("%{0}: Attribute batch_sie is still unset".format(method))
+            raise AttributeError(
+                "%{0}: Attribute batch_sie is still unset".format(method)
+            )
 
         test_tf_dataset = self.test_dataset.make_dataset(self.batch_size)
         test_iterator = test_tf_dataset.make_one_shot_iterator()
         # The `Iterator.string_handle()` method returns a tensor that can be evaluated
         # and used to feed the `handle` placeholder.
         test_handle = test_iterator.string_handle()
-        dataset_iterator = tf.data.Iterator.from_string_handle(test_handle, test_tf_dataset.output_types,
-                                                               test_tf_dataset.output_shapes)
+        dataset_iterator = tf.data.Iterator.from_string_handle(
+            test_handle, test_tf_dataset.output_types, test_tf_dataset.output_shapes
+        )
         input_iter = dataset_iterator.get_next()
         ts_iter = input_iter["T_start"]
 
@@ -389,11 +501,15 @@ class Postprocess(TrainModel):
         if not hasattr(self, "model"):
             raise AttributeError("%{0}: Attribute model is still unset".format(method))
         if not hasattr(self, "num_stochastic_samples"):
-            raise AttributeError("%{0}: Attribute num_stochastic_samples is still unset".format(method))
+            raise AttributeError(
+                "%{0}: Attribute num_stochastic_samples is still unset".format(method)
+            )
 
         if np.any(self.model in ["convLSTM", "test_model", "mcnet"]):
             if self.num_stochastic_samples > 1:
-                print("Number of samples for deterministic model cannot be larger than 1. Higher values are ignored.")
+                print(
+                    "Number of samples for deterministic model cannot be larger than 1. Higher values are ignored."
+                )
             self.num_stochastic_samples = 1
 
     # the run-factory
@@ -416,101 +532,149 @@ class Postprocess(TrainModel):
         self.restore(self.sess, self.checkpoint)
         # Loop for samples
         self.sample_ind = 0
-        self.prst_metric_all = []  # store evaluation metrics of persistence forecast (shape [future_len])
-        self.fcst_metric_all = []  # store evaluation metric of stochastic forecasts (shape [nstoch, batch, future_len])
+        self.prst_metric_all = (
+            []
+        )  # store evaluation metrics of persistence forecast (shape [future_len])
+        self.fcst_metric_all = (
+            []
+        )  # store evaluation metric of stochastic forecasts (shape [nstoch, batch, future_len])
         while self.sample_ind < self.num_samples_per_epoch:
             if self.num_samples_per_epoch < self.sample_ind:
                 break
             else:
                 # run the inputs and plot each sequence images
-                self.input_results, self.input_images_denorm_all, self.t_starts = self.get_input_data_per_batch()
-
-            feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()}
+                (
+                    self.input_results,
+                    self.input_images_denorm_all,
+                    self.t_starts,
+                ) = self.get_input_data_per_batch()
+
+            feed_dict = {
+                input_ph: self.input_results[name]
+                for name, input_ph in self.inputs.items()
+            }
             gen_loss_stochastic_batch = []  # [stochastic_ind,future_length]
-            gen_images_stochastic = []  # [stochastic_ind,batch_size,seq_len,lat,lon,channels]
+            gen_images_stochastic = (
+                []
+            )  # [stochastic_ind,batch_size,seq_len,lat,lon,channels]
             # Loop for stochastics
             for stochastic_sample_ind in range(self.num_stochastic_samples):
                 print("stochastic_sample_ind:", stochastic_sample_ind)
                 # return [batchsize,seq_len,lat,lon,channel]
-                gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)
+                gen_images = self.sess.run(
+                    self.video_model.outputs["gen_images"], feed_dict=feed_dict
+                )
                 # The generate images seq_len should be sequence_len -1, since the last one is
                 # not used for comparing with groud truth
                 assert gen_images.shape[1] == self.sequence_length - 1
                 gen_images_per_batch = []
                 if stochastic_sample_ind == 0:
-                    persistent_images_per_batch = []  # [batch_size,seq_len,lat,lon,channel]
+                    persistent_images_per_batch = (
+                        []
+                    )  # [batch_size,seq_len,lat,lon,channel]
                     ts_batch = []
                 for i in range(self.batch_size):
                     # generate time stamps for sequences only once, since they are the same for all ensemble members
                     if stochastic_sample_ind == 0:
-                        self.ts = Postprocess.generate_seq_timestamps(self.t_starts[i], len_seq=self.sequence_length)
+                        self.ts = Postprocess.generate_seq_timestamps(
+                            self.t_starts[i], len_seq=self.sequence_length
+                        )
                         init_date_str = self.ts[0].strftime("%Y%m%d%H")
                         ts_batch.append(init_date_str)
                         # get persistence_images
-                        self.persistence_images, self.ts_persistence = Postprocess.get_persistence(self.ts,
-                                                                                                   self.input_dir_pkl)
+                        (
+                            self.persistence_images,
+                            self.ts_persistence,
+                        ) = Postprocess.get_persistence(self.ts, self.input_dir_pkl)
                         persistent_images_per_batch.append(self.persistence_images)
                         assert len(np.array(persistent_images_per_batch).shape) == 5
                         self.plot_persistence_images()
 
                     # Denormalized data for generate
                     gen_images_ = gen_images[i]
-                    self.gen_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl, gen_images_,
-                                                                                    self.vars_in)
+                    self.gen_images_denorm = Postprocess.denorm_images_all_channels(
+                        self.stat_fl, gen_images_, self.vars_in
+                    )
                     gen_images_per_batch.append(self.gen_images_denorm)
                     assert len(np.array(gen_images_per_batch).shape) == 5
                     # only plot when the first stochastic ind otherwise too many plots would be created
                     # only plot the stochastic results of user-defined ind
-                    self.plot_generate_images(stochastic_sample_ind, self.stochastic_plot_id)
+                    self.plot_generate_images(
+                        stochastic_sample_ind, self.stochastic_plot_id
+                    )
                 # calculate the persistnet error per batch
                 if stochastic_sample_ind == 0:
-                    persistent_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all,
-                                                                                       persistent_images_per_batch,
-                                                                                       self.future_length,
-                                                                                       self.context_frames,
-                                                                                       matric="mse", channel=0)
+                    persistent_loss_per_batch = Postprocess.calculate_metrics_by_batch(
+                        self.input_images_denorm_all,
+                        persistent_images_per_batch,
+                        self.future_length,
+                        self.context_frames,
+                        matric="mse",
+                        channel=0,
+                    )
                     self.prst_metric_all.append(persistent_loss_per_batch)
 
                 # calculate the gen_images_per_batch error
-                gen_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all,
-                                                                            gen_images_per_batch, self.future_length,
-                                                                            self.context_frames,
-                                                                            matric="mse", channel=0)
+                gen_loss_per_batch = Postprocess.calculate_metrics_by_batch(
+                    self.input_images_denorm_all,
+                    gen_images_per_batch,
+                    self.future_length,
+                    self.context_frames,
+                    matric="mse",
+                    channel=0,
+                )
                 gen_loss_stochastic_batch.append(
-                    gen_loss_per_batch)  # self.gen_images_stochastic[stochastic,future_length]
-                print("gen_images_per_batch shape:", np.array(gen_images_per_batch).shape)
+                    gen_loss_per_batch
+                )  # self.gen_images_stochastic[stochastic,future_length]
+                print(
+                    "gen_images_per_batch shape:", np.array(gen_images_per_batch).shape
+                )
                 gen_images_stochastic.append(
-                    gen_images_per_batch)  # [stochastic,batch_size, seq_len, lat, lon, channel]
+                    gen_images_per_batch
+                )  # [stochastic,batch_size, seq_len, lat, lon, channel]
 
                 # Switch the 0 and 1 position
                 print("before transpose:", np.array(gen_images_stochastic).shape)
-            gen_images_stochastic = np.transpose(np.array(gen_images_stochastic), (
-                1, 0, 2, 3, 4, 5))  # [batch_size, stochastic, seq_len, lat, lon, chanel]
+            gen_images_stochastic = np.transpose(
+                np.array(gen_images_stochastic), (1, 0, 2, 3, 4, 5)
+            )  # [batch_size, stochastic, seq_len, lat, lon, chanel]
             Postprocess.check_gen_images_stochastic_shape(gen_images_stochastic)
             assert len(gen_images_stochastic.shape) == 6
-            assert np.array(gen_images_stochastic).shape[1] == self.num_stochastic_samples
+            assert (
+                np.array(gen_images_stochastic).shape[1] == self.num_stochastic_samples
+            )
 
             self.fcst_metric_all.append(
-                gen_loss_stochastic_batch)  # [samples/batch_size,stochastic,future_length]
+                gen_loss_stochastic_batch
+            )  # [samples/batch_size,stochastic,future_length]
             # save input and stochastic generate images to netcdf file
             # For each prediction (either deterministic or ensemble) we create one netCDF file.
             for batch_id in range(self.batch_size):
-                self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id],
-                                                                   persistent_images_per_batch[batch_id],
-                                                                   np.array(gen_images_stochastic)[batch_id],
-                                                                   fl_name="vfp_date_{}_sample_ind_{}.nc"
-                                                                   .format(ts_batch[batch_id],
-                                                                           self.sample_ind + batch_id))
+                self.save_to_netcdf_for_stochastic_generate_images(
+                    self.input_images_denorm_all[batch_id],
+                    persistent_images_per_batch[batch_id],
+                    np.array(gen_images_stochastic)[batch_id],
+                    fl_name="vfp_date_{}_sample_ind_{}.nc".format(
+                        ts_batch[batch_id], self.sample_ind + batch_id
+                    ),
+                )
 
             self.sample_ind += self.batch_size
 
-        self.persistent_loss_all_batches = np.mean(np.array(self.persistent_loss_all_batches), axis=0)
-        self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0)
+        self.persistent_loss_all_batches = np.mean(
+            np.array(self.persistent_loss_all_batches), axis=0
+        )
+        self.stochastic_loss_all_batches = np.mean(
+            np.array(self.stochastic_loss_all_batches), axis=0
+        )
         assert len(np.array(self.persistent_loss_all_batches).shape) == 1
         assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length
 
         assert len(np.array(self.stochastic_loss_all_batches).shape) == 2
-        assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples
+        assert (
+            np.array(self.stochastic_loss_all_batches).shape[0]
+            == self.num_stochastic_samples
+        )
 
     def run_deterministic(self):
         """
@@ -526,57 +690,107 @@ class Postprocess(TrainModel):
         sample_ind = 0
         nsamples = self.num_samples_per_epoch
         # initialize xarray datasets
-        eval_metric_ds = Postprocess.init_metric_ds(self.fcst_products, self.eval_metrics, self.vars_in[self.channel],
-                                                    nsamples, self.future_length)
+        eval_metric_ds = Postprocess.init_metric_ds(
+            self.fcst_products,
+            self.eval_metrics,
+            self.vars_in[self.channel],
+            nsamples,
+            self.future_length,
+        )
         cond_quantiple_ds = None
 
         while sample_ind < nsamples:
             # get normalized and denormalized input data
-            input_results, input_images_denorm, t_starts = self.get_input_data_per_batch(self.inputs)
+            (
+                input_results,
+                input_images_denorm,
+                t_starts,
+            ) = self.get_input_data_per_batch(self.inputs)
             # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel]
-            print("%{0}: Start generating {1:d} predictions at current sample index {2:d}".format(method, self.batch_size,
-                                                                                                  sample_ind))
-            feed_dict = {input_ph: input_results[name] for name, input_ph in self.inputs.items()}
-            gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict)
+            print(
+                "%{0}: Start generating {1:d} predictions at current sample index {2:d}".format(
+                    method, self.batch_size, sample_ind
+                )
+            )
+            feed_dict = {
+                input_ph: input_results[name] for name, input_ph in self.inputs.items()
+            }
+            gen_images = self.sess.run(
+                self.video_model.outputs["gen_images"], feed_dict=feed_dict
+            )
 
             # sanity check on length of forecast sequence
-            assert gen_images.shape[1] == self.sequence_length - 1, \
-                "%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(method)
+            assert (
+                gen_images.shape[1] == self.sequence_length - 1
+            ), "%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(
+                method
+            )
             # denormalize forecast sequence (self.norm_cls is already set in get_input_data_per_batch-method)
-            gen_images_denorm = self.denorm_images_all_channels(gen_images, self.vars_in, self.norm_cls,
-                                                                norm_method="minmax")
+            gen_images_denorm = self.denorm_images_all_channels(
+                gen_images, self.vars_in, self.norm_cls, norm_method="minmax"
+            )
             # store data into datset & get number of samples (may differ from batch_size at the end of the test dataset)
             times_0, init_times = self.get_init_time(t_starts)
-            batch_ds = self.create_dataset(input_images_denorm, gen_images_denorm, init_times)
+            batch_ds = self.create_dataset(
+                input_images_denorm, gen_images_denorm, init_times
+            )
             nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind)
             batch_ds = batch_ds.isel(init_time=slice(0, nbs))
 
             # run over mini-batch only if quick evaluation is NOT active
             for i in np.arange(0 if self.lquick else nbs):
-                print("%{0}: Process mini-batch sample {1:d}/{2:d}".format(method, i+1, nbs))
+                print(
+                    "%{0}: Process mini-batch sample {1:d}/{2:d}".format(
+                        method, i + 1, nbs
+                    )
+                )
                 # work-around to make use of get_persistence_forecast_per_sample-method
-                times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime()
+                times_seq = (
+                    pd.date_range(
+                        times_0[i], periods=int(self.sequence_length), freq="h"
+                    )
+                ).to_pydatetime()
                 # get persistence forecast for sequences at hand and write to dataset
-                persistence_seq, _ = Postprocess.get_persistence(times_seq, self.input_dir_pkl)
+                persistence_seq, _ = Postprocess.get_persistence(
+                    times_seq, self.input_dir_pkl
+                )
                 for ivar, var in enumerate(self.vars_in):
-                    batch_ds["{0}_persistence_fcst".format(var)].loc[dict(init_time=init_times[i])] = \
-                            persistence_seq[self.context_frames-1:, :, :, ivar]
+                    batch_ds["{0}_persistence_fcst".format(var)].loc[
+                        dict(init_time=init_times[i])
+                    ] = persistence_seq[self.context_frames - 1 :, :, :, ivar]
 
                 # save sequences to netcdf-file and track initial time
-                nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc"
-                                        .format(pd.to_datetime(init_times[i]).strftime("%Y%m%d%H"), sample_ind + i))
-                
+                nc_fname = os.path.join(
+                    self.results_dir,
+                    "vfp_date_{0}_sample_ind_{1:d}.nc".format(
+                        pd.to_datetime(init_times[i]).strftime("%Y%m%d%H"),
+                        sample_ind + i,
+                    ),
+                )
+
                 if os.path.exists(nc_fname):
-                    print("%{0}: The file '{1}' already exists and is therefore skipped".format(method, nc_fname))
+                    print(
+                        "%{0}: The file '{1}' already exists and is therefore skipped".format(
+                            method, nc_fname
+                        )
+                    )
                 else:
                     self.save_ds_to_netcdf(batch_ds.isel(init_time=i), nc_fname)
                 # end of batch-loop
             # write evaluation metric to corresponding dataset and sa
-            eval_metric_ds = self.populate_eval_metric_ds(eval_metric_ds, batch_ds, sample_ind,
-                                                          self.vars_in[self.channel])
-            if not self.lquick:             # conditional quantiles are not evaluated for quick evaluation
-                cond_quantiple_ds = Postprocess.append_ds(batch_ds, cond_quantiple_ds, self.cond_quantile_vars,
-                                                          "init_time", dtype=np.float16)
+            eval_metric_ds = self.populate_eval_metric_ds(
+                eval_metric_ds, batch_ds, sample_ind, self.vars_in[self.channel]
+            )
+            if (
+                not self.lquick
+            ):  # conditional quantiles are not evaluated for quick evaluation
+                cond_quantiple_ds = Postprocess.append_ds(
+                    batch_ds,
+                    cond_quantiple_ds,
+                    self.cond_quantile_vars,
+                    "init_time",
+                    dtype=np.float16,
+                )
             # ... and increment sample_ind
             sample_ind += self.batch_size
             # end of while-loop for samples
@@ -584,7 +798,7 @@ class Postprocess(TrainModel):
         self.eval_metrics_ds = eval_metric_ds
         self.cond_quantiple_ds = cond_quantiple_ds
         self.sess.close()
-             
+
     # all methods of the run factory
     def init_session(self):
         """
@@ -617,15 +831,23 @@ class Postprocess(TrainModel):
         t_starts = input_results["T_start"]
         if self.norm_cls is None:
             if self.stat_fl is None:
-                raise AttributeError("%{0}: Attribute stat_fl is not initialized yet.".format(method))
-            self.norm_cls = Postprocess.get_norm(self.vars_in, self.stat_fl, norm_method)
+                raise AttributeError(
+                    "%{0}: Attribute stat_fl is not initialized yet.".format(method)
+                )
+            self.norm_cls = Postprocess.get_norm(
+                self.vars_in, self.stat_fl, norm_method
+            )
 
         # sanity check on input sequence
-        assert np.ndim(input_images) == 5, "%{0}: Input sequence of mini-batch does not have five dimensions."\
-                                           .format(method)
+        assert (
+            np.ndim(input_images) == 5
+        ), "%{0}: Input sequence of mini-batch does not have five dimensions.".format(
+            method
+        )
 
-        input_images_denorm = Postprocess.denorm_images_all_channels(input_images, self.vars_in, self.norm_cls,
-                                                                     norm_method=norm_method)
+        input_images_denorm = Postprocess.denorm_images_all_channels(
+            input_images, self.vars_in, self.norm_cls, norm_method=norm_method
+        )
 
         return input_results, input_images_denorm, t_starts
 
@@ -639,15 +861,24 @@ class Postprocess(TrainModel):
 
         t_starts = np.squeeze(np.asarray(t_starts))
         if not np.ndim(t_starts) == 1:
-            raise ValueError("%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H"
-                             .format(method))
+            raise ValueError(
+                "%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H".format(
+                    method
+                )
+            )
         for i, t_start in enumerate(t_starts):
             try:
-                seq_ts = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H"), periods=self.context_frames,
-                                       freq="h")
+                seq_ts = pd.date_range(
+                    dt.datetime.strptime(str(t_start), "%Y%m%d%H"),
+                    periods=self.context_frames,
+                    freq="h",
+                )
             except Exception as err:
-                print("%{0}: Could not convert {1} to datetime object. Ensure that the date-string format is 'Y%m%d%H'".
-                      format(method, str(t_start)))
+                print(
+                    "%{0}: Could not convert {1} to datetime object. Ensure that the date-string format is 'Y%m%d%H'".format(
+                        method, str(t_start)
+                    )
+                )
                 raise err
             if i == 0:
                 ts_all = np.expand_dims(seq_ts, axis=0)
@@ -672,7 +903,9 @@ class Postprocess(TrainModel):
 
         # dictionary of implemented evaluation metrics
         dims = ["lat", "lon"]
-        eval_metrics_func = [Scores(metric, dims).score_func for metric in self.eval_metrics]
+        eval_metrics_func = [
+            Scores(metric, dims).score_func for metric in self.eval_metrics
+        ]
         varname_ref = "{0}_ref".format(varname)
         # reset init-time coordinate of metric_ds in place and get indices for slicing
         ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch)
@@ -685,12 +918,14 @@ class Postprocess(TrainModel):
                 metric_name = "{0}_{1}_{2}".format(varname, fcst_prod, eval_metric)
                 varname_fcst = "{0}_{1}_fcst".format(varname, fcst_prod)
                 dict_ind = dict(init_time=data_ds["init_time"])
-                metric_ds[metric_name].loc[dict_ind] = eval_metrics_func[imetric](data_fcst=data_ds[varname_fcst],
-                                                                                  data_ref=data_ds[varname_ref],
-                                                                                  data_clim=self.data_clim)
+                metric_ds[metric_name].loc[dict_ind] = eval_metrics_func[imetric](
+                    data_fcst=data_ds[varname_fcst],
+                    data_ref=data_ds[varname_ref],
+                    data_clim=self.data_clim,
+                )
             # end of metric-loop
         # end of forecast product-loop
-        
+
         return metric_ds
 
     def add_ensemble_dim(self):
@@ -698,8 +933,12 @@ class Postprocess(TrainModel):
         Expands dimensions of loss-arrays by dummy ensemble-dimension (used for deterministic forecasts only)
         :return:
         """
-        self.stochastic_loss_all_batches = np.expand_dims(self.fcst_mse_avg_batches, axis=0)  # [1,future_lenght]
-        self.stochastic_loss_all_batches_psnr = np.expand_dims(self.fcst_psnr_avg_batches, axis=0)  # [1,future_lenght]
+        self.stochastic_loss_all_batches = np.expand_dims(
+            self.fcst_mse_avg_batches, axis=0
+        )  # [1,future_lenght]
+        self.stochastic_loss_all_batches_psnr = np.expand_dims(
+            self.fcst_psnr_avg_batches, axis=0
+        )  # [1,future_lenght]
 
     def create_dataset(self, input_seq, fcst_seq, ts_ini):
         """
@@ -714,58 +953,118 @@ class Postprocess(TrainModel):
         method = Postprocess.create_dataset.__name__
 
         # auxiliary variables for temporal dimensions
-        seq_hours = np.arange(self.sequence_length) - (self.context_frames-1)
+        seq_hours = np.arange(self.sequence_length) - (self.context_frames - 1)
         # some sanity checks
-        assert np.shape(ts_ini)[0] == self.batch_size,\
-            "%{0}: Inconsistent number of sequence start times ({1:d}) and batch size ({2:d})"\
-            .format(method, np.shape(ts_ini)[0], self.batch_size)
+        assert (
+            np.shape(ts_ini)[0] == self.batch_size
+        ), "%{0}: Inconsistent number of sequence start times ({1:d}) and batch size ({2:d})".format(
+            method, np.shape(ts_ini)[0], self.batch_size
+        )
 
         # turn input and forecast sequences to Data Arrays to ease indexing
         try:
-            input_seq = xr.DataArray(input_seq, coords={"init_time": ts_ini, "fcst_hour": seq_hours,
-                                                        "lat": self.lats, "lon": self.lons, "varname": self.vars_in},
-                                     dims=["init_time", "fcst_hour", "lat", "lon", "varname"])
+            input_seq = xr.DataArray(
+                input_seq,
+                coords={
+                    "init_time": ts_ini,
+                    "fcst_hour": seq_hours,
+                    "lat": self.lats,
+                    "lon": self.lons,
+                    "varname": self.vars_in,
+                },
+                dims=["init_time", "fcst_hour", "lat", "lon", "varname"],
+            )
         except Exception as err:
-            print("%{0}: Could not create Data Array for input sequence.".format(method))
+            print(
+                "%{0}: Could not create Data Array for input sequence.".format(method)
+            )
             raise err
 
         try:
-            fcst_seq = xr.DataArray(fcst_seq, coords={"init_time": ts_ini, "fcst_hour": seq_hours[1::],
-                                                      "lat": self.lats, "lon": self.lons, "varname": self.vars_in},
-                                    dims=["init_time", "fcst_hour", "lat", "lon", "varname"])
+            fcst_seq = xr.DataArray(
+                fcst_seq,
+                coords={
+                    "init_time": ts_ini,
+                    "fcst_hour": seq_hours[1::],
+                    "lat": self.lats,
+                    "lon": self.lons,
+                    "varname": self.vars_in,
+                },
+                dims=["init_time", "fcst_hour", "lat", "lon", "varname"],
+            )
         except Exception as err:
-            print("%{0}: Could not create Data Array for forecast sequence.".format(method))
+            print(
+                "%{0}: Could not create Data Array for forecast sequence.".format(
+                    method
+                )
+            )
             raise err
 
         # Now create the dataset where the input sequence is splitted into input that served for creating the
         # forecast and into the the reference sequences (which can be compared to the forecast)
         # as where the persistence forecast is containing NaNs (must be generated later)
-        data_in_dict = dict([("{0}_in".format(var), input_seq.isel(fcst_hour=slice(None, self.context_frames),
-                                                                   varname=ivar)
-                                                             .rename({"fcst_hour": "in_hour"})
-                                                             .reset_coords(names="varname", drop=True))
-                             for ivar, var in enumerate(self.vars_in)])
+        data_in_dict = dict(
+            [
+                (
+                    "{0}_in".format(var),
+                    input_seq.isel(
+                        fcst_hour=slice(None, self.context_frames), varname=ivar
+                    )
+                    .rename({"fcst_hour": "in_hour"})
+                    .reset_coords(names="varname", drop=True),
+                )
+                for ivar, var in enumerate(self.vars_in)
+            ]
+        )
 
         # get shape of forecast data (one variable) -> required to initialize persistence forecast data
-        shape_fcst = np.shape(fcst_seq.isel(fcst_hour=slice(self.context_frames-1, None), varname=0)
-                                      .reset_coords(names="varname", drop=True))
-        data_ref_dict = dict([("{0}_ref".format(var), input_seq.isel(fcst_hour=slice(self.context_frames, None),
-                                                                     varname=ivar)
-                                                               .reset_coords(names="varname", drop=True))
-                              for ivar, var in enumerate(self.vars_in)])
-
-        data_mfcst_dict = dict([("{0}_{1}_fcst".format(var, self.model),
-                                 fcst_seq.isel(fcst_hour=slice(self.context_frames-1, None), varname=ivar)
-                                         .reset_coords(names="varname", drop=True))
-                                for ivar, var in enumerate(self.vars_in)])
+        shape_fcst = np.shape(
+            fcst_seq.isel(
+                fcst_hour=slice(self.context_frames - 1, None), varname=0
+            ).reset_coords(names="varname", drop=True)
+        )
+        data_ref_dict = dict(
+            [
+                (
+                    "{0}_ref".format(var),
+                    input_seq.isel(
+                        fcst_hour=slice(self.context_frames, None), varname=ivar
+                    ).reset_coords(names="varname", drop=True),
+                )
+                for ivar, var in enumerate(self.vars_in)
+            ]
+        )
+
+        data_mfcst_dict = dict(
+            [
+                (
+                    "{0}_{1}_fcst".format(var, self.model),
+                    fcst_seq.isel(
+                        fcst_hour=slice(self.context_frames - 1, None), varname=ivar
+                    ).reset_coords(names="varname", drop=True),
+                )
+                for ivar, var in enumerate(self.vars_in)
+            ]
+        )
 
         # fill persistence forecast variables with dummy data (to be populated later)
-        data_pfcst_dict = dict([("{0}_persistence_fcst".format(var), (["init_time", "fcst_hour", "lat", "lon"],
-                                                                      np.full(shape_fcst, np.nan)))
-                                for ivar, var in enumerate(self.vars_in)])
+        data_pfcst_dict = dict(
+            [
+                (
+                    "{0}_persistence_fcst".format(var),
+                    (
+                        ["init_time", "fcst_hour", "lat", "lon"],
+                        np.full(shape_fcst, np.nan),
+                    ),
+                )
+                for ivar, var in enumerate(self.vars_in)
+            ]
+        )
 
         # create the dataset
-        data_ds = xr.Dataset({**data_in_dict, **data_ref_dict, **data_mfcst_dict, **data_pfcst_dict})
+        data_ds = xr.Dataset(
+            {**data_in_dict, **data_ref_dict, **data_mfcst_dict, **data_pfcst_dict}
+        )
 
         return data_ds
 
@@ -777,11 +1076,16 @@ class Postprocess(TrainModel):
         method = Postprocess.handle_eval_metrics.__name__
 
         if self.eval_metrics_ds is None:
-            raise AttributeError("%{0}: Attribute with dataset of evaluation metrics is still None.".format(method))
+            raise AttributeError(
+                "%{0}: Attribute with dataset of evaluation metrics is still None.".format(
+                    method
+                )
+            )
 
         # perform bootstrapping on metric dataset
-        eval_metric_boot_ds = perform_block_bootstrap_metric(self.eval_metrics_ds, "init_time", self.block_length,
-                                                             self.nboots_block)
+        eval_metric_boot_ds = perform_block_bootstrap_metric(
+            self.eval_metrics_ds, "init_time", self.block_length, self.nboots_block
+        )
         # ... and merge into existing metric dataset
         self.eval_metrics_ds = xr.merge([self.eval_metrics_ds, eval_metric_boot_ds])
 
@@ -795,8 +1099,13 @@ class Postprocess(TrainModel):
         Postprocess.save_ds_to_netcdf(self.eval_metrics_ds, nc_fname)
 
         # also save averaged metrics to JSON-file and plot it for diagnosis
-        _ = plot_avg_eval_metrics(self.eval_metrics_ds, self.eval_metrics, self.fcst_products,
-                                  self.vars_in[self.channel], self.results_dir)
+        _ = plot_avg_eval_metrics(
+            self.eval_metrics_ds,
+            self.eval_metrics,
+            self.fcst_products,
+            self.vars_in[self.channel],
+            self.results_dir,
+        )
 
     def plot_example_forecasts(self, metric="mse", channel=0):
         """
@@ -811,20 +1120,30 @@ class Postprocess(TrainModel):
 
         metric_name = "{0}_{1}_{2}".format(self.vars_in[channel], self.model, metric)
         if not metric_name in self.eval_metrics_ds:
-            raise ValueError("%{0}: Cannot find requested evaluation metric '{1}'".format(method, metric_name) +
-                             " onto which selection of plotted forecast is done.")
+            raise ValueError(
+                "%{0}: Cannot find requested evaluation metric '{1}'".format(
+                    method, metric_name
+                )
+                + " onto which selection of plotted forecast is done."
+            )
         # average metric of interest and obtain quantiles incl. indices
         metric_mean = self.eval_metrics_ds[metric_name].mean(dim="fcst_hour")
-        quantiles = np.arange(0., 1.01, .1)
+        quantiles = np.arange(0.0, 1.01, 0.1)
         quantiles_val = metric_mean.quantile(quantiles, interpolation="nearest")
         quantiles_inds = self.get_matching_indices(metric_mean.values, quantiles_val)
 
         for i, ifcst in enumerate(quantiles_inds):
             date_init = pd.to_datetime(metric_mean.coords["init_time"][ifcst].data)
-            nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc"
-                                    .format(date_init.strftime("%Y%m%d%H"), ifcst))
+            nc_fname = os.path.join(
+                self.results_dir,
+                "vfp_date_{0}_sample_ind_{1:d}.nc".format(
+                    date_init.strftime("%Y%m%d%H"), ifcst
+                ),
+            )
             if not os.path.isfile(nc_fname):
-                raise FileNotFoundError("%{0}: Could not find requested file '{1}'".format(method, nc_fname))
+                raise FileNotFoundError(
+                    "%{0}: Could not find requested file '{1}'".format(method, nc_fname)
+                )
             else:
                 # get the data
                 varname = self.vars_in[channel]
@@ -834,9 +1153,15 @@ class Postprocess(TrainModel):
 
                 data_diff = data_fcst - data_ref
                 # name of plot
-                plt_fname_base = os.path.join(self.output_dir, "forecast_{0}_{1}_{2}_{3:d}percentile.png"
-                                              .format(varname, date_init.strftime("%Y%m%dT%H00"), metric,
-                                                      int(quantiles[i] * 100.)))
+                plt_fname_base = os.path.join(
+                    self.output_dir,
+                    "forecast_{0}_{1}_{2}_{3:d}percentile.png".format(
+                        varname,
+                        date_init.strftime("%Y%m%dT%H00"),
+                        metric,
+                        int(quantiles[i] * 100.0),
+                    ),
+                )
 
                 create_geo_contour_plot(data_fcst, data_diff, varname, plt_fname_base)
 
@@ -849,29 +1174,43 @@ class Postprocess(TrainModel):
         var_fcst = self.cond_quantile_vars[0]
         var_ref = self.cond_quantile_vars[1]
 
-        data_fcst = get_era5_varatts(self.cond_quantiple_ds[var_fcst], self.cond_quantiple_ds[var_fcst].name)
-        data_ref = get_era5_varatts(self.cond_quantiple_ds[var_ref], self.cond_quantiple_ds[var_ref].name)
+        data_fcst = get_era5_varatts(
+            self.cond_quantiple_ds[var_fcst], self.cond_quantiple_ds[var_fcst].name
+        )
+        data_ref = get_era5_varatts(
+            self.cond_quantiple_ds[var_ref], self.cond_quantiple_ds[var_ref].name
+        )
 
         # create plots
         fhhs = data_fcst.coords["fcst_hour"]
         for hh in fhhs:
             # calibration refinement factorization
-            plt_fname_cf = os.path.join(self.results_dir, "cond_quantile_{0}_{1}_fh{2:0d}_calibration_refinement.png"
-                                        .format(self.vars_in[self.channel], self.model, int(hh)))
-
-            quantile_panel_cf, cond_variable_cf = calculate_cond_quantiles(data_fcst.sel(fcst_hour=hh),
-                                                                           data_ref.sel(fcst_hour=hh),
-                                                                           factorization="calibration_refinement",
-                                                                           quantiles=(0.05, 0.5, 0.95))
+            plt_fname_cf = os.path.join(
+                self.results_dir,
+                "cond_quantile_{0}_{1}_fh{2:0d}_calibration_refinement.png".format(
+                    self.vars_in[self.channel], self.model, int(hh)
+                ),
+            )
+
+            quantile_panel_cf, cond_variable_cf = calculate_cond_quantiles(
+                data_fcst.sel(fcst_hour=hh),
+                data_ref.sel(fcst_hour=hh),
+                factorization="calibration_refinement",
+                quantiles=(0.05, 0.5, 0.95),
+            )
 
             plot_cond_quantile(quantile_panel_cf, cond_variable_cf, plt_fname_cf)
 
             # likelihood-base rate factorization
-            plt_fname_lbr = plt_fname_cf.replace("calibration_refinement", "likelihood-base_rate")
-            quantile_panel_lbr, cond_variable_lbr = calculate_cond_quantiles(data_fcst.sel(fcst_hour=hh),
-                                                                             data_ref.sel(fcst_hour=hh),
-                                                                             factorization="likelihood-base_rate",
-                                                                             quantiles=(0.05, 0.5, 0.95))
+            plt_fname_lbr = plt_fname_cf.replace(
+                "calibration_refinement", "likelihood-base_rate"
+            )
+            quantile_panel_lbr, cond_variable_lbr = calculate_cond_quantiles(
+                data_fcst.sel(fcst_hour=hh),
+                data_ref.sel(fcst_hour=hh),
+                factorization="likelihood-base_rate",
+                quantiles=(0.05, 0.5, 0.95),
+            )
 
             plot_cond_quantile(quantile_panel_lbr, cond_variable_lbr, plt_fname_lbr)
 
@@ -885,12 +1224,20 @@ class Postprocess(TrainModel):
         """
         method = Postprocess.reduce_samples.__name__
 
-        if frac_data <= 0. or frac_data >= 1.:
-            print("%{0}: frac_data is not within [0..1] and is therefore ignored.".format(method))
+        if frac_data <= 0.0 or frac_data >= 1.0:
+            print(
+                "%{0}: frac_data is not within [0..1] and is therefore ignored.".format(
+                    method
+                )
+            )
             return nsamples
         else:
-            nsamples_new = int(np.ceil(nsamples*frac_data))
-            print("%{0}: Sample size is reduced from {1:d} to {2:d}".format(method, int(nsamples), nsamples_new))
+            nsamples_new = int(np.ceil(nsamples * frac_data))
+            print(
+                "%{0}: Sample size is reduced from {1:d} to {2:d}".format(
+                    method, int(nsamples), nsamples_new
+                )
+            )
             return nsamples_new
 
     @staticmethod
@@ -905,7 +1252,11 @@ class Postprocess(TrainModel):
         method = Postprocess.clean_obj_attribute.__name__
 
         if not hasattr(obj, attr_name):
-            print("%{0}: Class attribute '{1}' does not exist. Nothing to do...".format(method, attr_name))
+            print(
+                "%{0}: Class attribute '{1}' does not exist. Nothing to do...".format(
+                    method, attr_name
+                )
+            )
         else:
             if lremove:
                 delattr(obj, attr_name)
@@ -927,7 +1278,9 @@ class Postprocess(TrainModel):
         method = Postprocess.get_norm.__name__
 
         if not isinstance(varnames, list):
-            raise ValueError("%{0}: varnames must be a list of variable names.".format(method))
+            raise ValueError(
+                "%{0}: varnames must be a list of variable names.".format(method)
+            )
 
         norm_cls = Norm_data(varnames)
         try:
@@ -935,12 +1288,18 @@ class Postprocess(TrainModel):
                 norm_cls.check_and_set_norm(json.load(js_file), norm_method)
             norm_cls = norm_cls
         except Exception as err:
-            print("%{0}: Could not handle statistics json-file '{1}'.".format(method, stat_fl))
+            print(
+                "%{0}: Could not handle statistics json-file '{1}'.".format(
+                    method, stat_fl
+                )
+            )
             raise err
         return norm_cls
 
     @staticmethod
-    def denorm_images_all_channels(image_sequence, varnames, norm, norm_method="minmax"):
+    def denorm_images_all_channels(
+        image_sequence, varnames, norm, norm_method="minmax"
+    ):
         """
         Denormalize data of all image channels
         :param image_sequence: list/array [batch, seq, lat, lon, channel] of images
@@ -955,15 +1314,23 @@ class Postprocess(TrainModel):
         image_sequence = np.array(image_sequence)
         # sanity checks
         if not isinstance(norm, Norm_data):
-            raise ValueError("%{0}: norm must be a normalization instance.".format(method))
+            raise ValueError(
+                "%{0}: norm must be a normalization instance.".format(method)
+            )
 
         if nvars != np.shape(image_sequence)[-1]:
-            raise ValueError("%{0}: Number of passed variable names ({1:d}) does not match number of channels ({2:d})"
-                             .format(method, nvars, np.shape(image_sequence)[-1]))
-
-        input_images_all_channles_denorm = [Postprocess.denorm_images(image_sequence, norm, {varname: c},
-                                                                      norm_method=norm_method)
-                                            for c, varname in enumerate(varnames)]
+            raise ValueError(
+                "%{0}: Number of passed variable names ({1:d}) does not match number of channels ({2:d})".format(
+                    method, nvars, np.shape(image_sequence)[-1]
+                )
+            )
+
+        input_images_all_channles_denorm = [
+            Postprocess.denorm_images(
+                image_sequence, norm, {varname: c}, norm_method=norm_method
+            )
+            for c, varname in enumerate(varnames)
+        ]
 
         input_images_denorm = np.stack(input_images_all_channles_denorm, axis=-1)
         return input_images_denorm
@@ -984,16 +1351,26 @@ class Postprocess(TrainModel):
             raise ValueError("%{0}: var_dict is not a dictionary.".format(method))
         else:
             if len(var_dict.keys()) > 1:
-                raise ValueError("%{0}: var_dict must contain one key only.".format(method))
+                raise ValueError(
+                    "%{0}: var_dict must contain one key only.".format(method)
+                )
             varname, channel = *var_dict.keys(), *var_dict.values()
 
         if not isinstance(norm, Norm_data):
-            raise ValueError("%{0}: norm must be a normalization instance.".format(method))
+            raise ValueError(
+                "%{0}: norm must be a normalization instance.".format(method)
+            )
 
         try:
-            input_images_denorm = norm.denorm_var(input_images[..., channel], varname, norm_method)
+            input_images_denorm = norm.denorm_var(
+                input_images[..., channel], varname, norm_method
+            )
         except Exception as err:
-            print("%{0}: Something went wrong when denormalizing image sequence. Inspect error-message!".format(method))
+            print(
+                "%{0}: Something went wrong when denormalizing image sequence. Inspect error-message!".format(
+                    method
+                )
+            )
             raise err
 
         return input_images_denorm
@@ -1024,7 +1401,9 @@ class Postprocess(TrainModel):
         """
         ts_persistence = []
         year_origin = ts[0].year
-        for t in range(len(ts)):  # Scarlet: this certainly can be made nicer with list comprehension
+        for t in range(
+            len(ts)
+        ):  # Scarlet: this certainly can be made nicer with list comprehension
             ts_temp = ts[t] - dt.timedelta(days=1)
             ts_persistence.append(ts_temp)
         t_persistence_start = ts_persistence[0]
@@ -1035,13 +1414,29 @@ class Postprocess(TrainModel):
         # only one pickle file is needed (all hours during the same month)
         if month_start == month_end:
             # Open files to search for the indizes of the corresponding time
-            time_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T'))
+            time_pickle = list(
+                Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_start, "T"
+                )
+            )
             # Open file to search for the correspoding meteorological fields
-            var_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X'))
+            var_pickle = list(
+                Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_start, "X"
+                )
+            )
 
             if year_origin != year_start:
-                time_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'T'))
-                var_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'X'))
+                time_origin_pickle = list(
+                    Postprocess.load_pickle_for_persistence(
+                        input_dir_pkl, year_origin, 12, "T"
+                    )
+                )
+                var_origin_pickle = list(
+                    Postprocess.load_pickle_for_persistence(
+                        input_dir_pkl, year_origin, 12, "X"
+                    )
+                )
                 time_pickle.extend(time_origin_pickle)
                 var_pickle.extend(var_origin_pickle)
 
@@ -1049,11 +1444,15 @@ class Postprocess(TrainModel):
             try:
                 ind = list(time_pickle).index(np.array(ts_persistence[0]))
             except Exception as err:
-                print("Please consider return Data preprocess step 1 to generate entire month data")
+                print(
+                    "Please consider return Data preprocess step 1 to generate entire month data"
+                )
                 raise err
 
-            var_persistence = np.array(var_pickle)[ind:ind + len(ts_persistence)]
-            time_persistence = np.array(time_pickle)[ind:ind + len(ts_persistence)].ravel()
+            var_persistence = np.array(var_pickle)[ind : ind + len(ts_persistence)]
+            time_persistence = np.array(time_pickle)[
+                ind : ind + len(ts_persistence)
+            ].ravel()
         # case that we need to derive the data from two pickle files (changing month during the forecast periode)
         else:
             t_persistence_first_m = []  # should hold dates of the first month
@@ -1067,35 +1466,69 @@ class Postprocess(TrainModel):
                     t_persistence_second_m.append(ts_persistence[t])
             if year_origin == year_start:
                 # Open files to search for the indizes of the corresponding time
-                time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T')
-                time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'T')
+                time_pickle_first = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_start, "T"
+                )
+                time_pickle_second = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_end, "T"
+                )
 
                 # Open file to search for the correspoding meteorological fields
-                var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X')
-                var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'X')
+                var_pickle_first = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_start, "X"
+                )
+                var_pickle_second = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, month_end, "X"
+                )
 
             if year_origin != year_start:
                 # Open files to search for the indizes of the corresponding time
-                time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 1, 'T')
-                time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, 12, 'T')
+                time_pickle_second = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_origin, 1, "T"
+                )
+                time_pickle_first = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, 12, "T"
+                )
 
                 # Open file to search for the correspoding meteorological fields
-                var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 1, 'X')
-                var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, 12, 'X')
+                var_pickle_second = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_origin, 1, "X"
+                )
+                var_pickle_first = Postprocess.load_pickle_for_persistence(
+                    input_dir_pkl, year_start, 12, "X"
+                )
 
             # Retrieve starting index
-            ind_first_m = list(time_pickle_first).index(np.array(t_persistence_first_m[0]))
-            ind_second_m = list(time_pickle_second).index(np.array(t_persistence_second_m[0]))
+            ind_first_m = list(time_pickle_first).index(
+                np.array(t_persistence_first_m[0])
+            )
+            ind_second_m = list(time_pickle_second).index(
+                np.array(t_persistence_second_m[0])
+            )
 
             # append the sequence of the second month to the first month
-            var_persistence = np.concatenate((var_pickle_first[ind_first_m:ind_first_m + len(t_persistence_first_m)],
-                                              var_pickle_second[
-                                              ind_second_m:ind_second_m + len(t_persistence_second_m)]),
-                                             axis=0)
-            time_persistence = np.concatenate((time_pickle_first[ind_first_m:ind_first_m + len(t_persistence_first_m)],
-                                               time_pickle_second[
-                                               ind_second_m:ind_second_m + len(t_persistence_second_m)]),
-                                              axis=0).ravel()
+            var_persistence = np.concatenate(
+                (
+                    var_pickle_first[
+                        ind_first_m : ind_first_m + len(t_persistence_first_m)
+                    ],
+                    var_pickle_second[
+                        ind_second_m : ind_second_m + len(t_persistence_second_m)
+                    ],
+                ),
+                axis=0,
+            )
+            time_persistence = np.concatenate(
+                (
+                    time_pickle_first[
+                        ind_first_m : ind_first_m + len(t_persistence_first_m)
+                    ],
+                    time_pickle_second[
+                        ind_second_m : ind_second_m + len(t_persistence_second_m)
+                    ],
+                ),
+                axis=0,
+            ).ravel()
             # Note: ravel is needed to eliminate the unnecessary dimension (20,1) becomes (20,)
 
         if len(time_persistence.tolist()) == 0:
@@ -1119,15 +1552,21 @@ class Postprocess(TrainModel):
         :param month_start: The year for which data is requested as integer
         :param pkl_type: Either "X" or "T"
         """
-        path_to_pickle = os.path.join(input_dir_pkl, str(year_start), pkl_type + "_{:02}.pkl".format(month_start))
+        path_to_pickle = os.path.join(
+            input_dir_pkl, str(year_start), pkl_type + "_{:02}.pkl".format(month_start)
+        )
         try:
             with open(path_to_pickle, "rb") as pkl_file:
                 var = pickle.load(pkl_file)
                 return var
         except Exception as e:
-            
-            print("The pickle file {} does not generated, please consider re-generate the pickle data in the preprocessing step 1",path_to_pickle)
-            raise(e)
+
+            print(
+                "The pickle file {} does not generated, please consider re-generate the pickle data in the preprocessing step 1",
+                path_to_pickle,
+            )
+            raise (e)
+
     @staticmethod
     def save_ds_to_netcdf(ds, nc_fname, comp_level=5):
         """
@@ -1141,29 +1580,55 @@ class Postprocess(TrainModel):
 
         # sanity checks
         if not isinstance(ds, xr.Dataset):
-            raise ValueError("%{0}: Argument 'ds' must be a xarray dataset.".format(method))
+            raise ValueError(
+                "%{0}: Argument 'ds' must be a xarray dataset.".format(method)
+            )
 
         if not isinstance(comp_level, int):
-            raise ValueError("%{0}: Argument 'comp_level' must be an integer.".format(method))
+            raise ValueError(
+                "%{0}: Argument 'comp_level' must be an integer.".format(method)
+            )
         else:
             if comp_level < 1 or comp_level > 9:
-                raise ValueError("%{0}: Argument 'comp_level' must be an integer between 1 and 9.".format(method))
+                raise ValueError(
+                    "%{0}: Argument 'comp_level' must be an integer between 1 and 9.".format(
+                        method
+                    )
+                )
 
         if not os.path.isdir(os.path.dirname(nc_fname)):
-            raise NotADirectoryError("%{0}: The directory to store the netCDf-file does not exist.".format(method))
+            raise NotADirectoryError(
+                "%{0}: The directory to store the netCDf-file does not exist.".format(
+                    method
+                )
+            )
 
         encode_nc = {key: {"zlib": True, "complevel": comp_level} for key in ds.keys()}
-        
+
         # populate data in netCDF-file (take care for the mode!)
         try:
-            ds.to_netcdf(nc_fname, encoding=encode_nc,engine="netcdf4")
-            print("%{0}: netCDF-file '{1}' was created successfully.".format(method, nc_fname))
+            ds.to_netcdf(nc_fname, encoding=encode_nc, engine="netcdf4")
+            print(
+                "%{0}: netCDF-file '{1}' was created successfully.".format(
+                    method, nc_fname
+                )
+            )
         except Exception as err:
-            print("%{0}: Something unexpected happened when creating netCDF-file '1'".format(method, nc_fname))
+            print(
+                "%{0}: Something unexpected happened when creating netCDF-file '1'".format(
+                    method, nc_fname
+                )
+            )
             raise err
 
     @staticmethod
-    def append_ds(ds_in: xr.Dataset, ds_preexist: xr.Dataset, varnames: list, dim2append: str, dtype=None):
+    def append_ds(
+        ds_in: xr.Dataset,
+        ds_preexist: xr.Dataset,
+        varnames: list,
+        dim2append: str,
+        dtype=None,
+    ):
         """
         Append existing datset with subset of dataset based on selected variables
         :param ds_in: the input dataset from which variables should be retrieved
@@ -1177,33 +1642,56 @@ class Postprocess(TrainModel):
         varnames_str = ",".join(varnames)
         # sanity checks
         if not isinstance(ds_in, xr.Dataset):
-            raise ValueError("%{0}: ds_in must be a xarray dataset, but is of type {1}".format(method, type(ds_in)))
+            raise ValueError(
+                "%{0}: ds_in must be a xarray dataset, but is of type {1}".format(
+                    method, type(ds_in)
+                )
+            )
 
         if not set(varnames).issubset(ds_in.data_vars):
-            raise ValueError("%{0}: Could not find all variables ({1}) in input dataset ds_in.".format(method,
-                                                                                                       varnames_str))
+            raise ValueError(
+                "%{0}: Could not find all variables ({1}) in input dataset ds_in.".format(
+                    method, varnames_str
+                )
+            )
         if dtype is None:
             dtype = np.double
         else:
             if not np.issubdtype(dtype, np.number):
-                raise ValueError("%{0}: dytpe must be a NumPy datatype, but is '{1}'".format(method, np.dtype(dtype)))
-  
+                raise ValueError(
+                    "%{0}: dytpe must be a NumPy datatype, but is '{1}'".format(
+                        method, np.dtype(dtype)
+                    )
+                )
+
         if ds_preexist is None:
             ds_preexist = ds_in[varnames].copy(deep=True)
-            ds_preexist = ds_preexist.astype(dtype)                           # change data type (if necessary)
+            ds_preexist = ds_preexist.astype(dtype)  # change data type (if necessary)
             return ds_preexist
         else:
             if not isinstance(ds_preexist, xr.Dataset):
-                raise ValueError("%{0}: ds_preexist must be a xarray dataset, but is of type {1}"
-                                 .format(method, type(ds_preexist)))
+                raise ValueError(
+                    "%{0}: ds_preexist must be a xarray dataset, but is of type {1}".format(
+                        method, type(ds_preexist)
+                    )
+                )
             if not set(varnames).issubset(ds_preexist.data_vars):
-                raise ValueError("%{0}: Could not find all varibales ({1}) in pre-existing dataset ds_preexist"
-                                 .format(method, varnames_str))
+                raise ValueError(
+                    "%{0}: Could not find all varibales ({1}) in pre-existing dataset ds_preexist".format(
+                        method, varnames_str
+                    )
+                )
 
         try:
-            ds_preexist = xr.concat([ds_preexist, ds_in[varnames].astype(dtype)], dim2append)
+            ds_preexist = xr.concat(
+                [ds_preexist, ds_in[varnames].astype(dtype)], dim2append
+            )
         except Exception as err:
-            print("%{0}: Failed to concat datsets along dimension {1}.".format(method, dim2append))
+            print(
+                "%{0}: Failed to concat datsets along dimension {1}.".format(
+                    method, dim2append
+                )
+            )
             print(ds_in)
             print(ds_preexist)
             raise err
@@ -1221,13 +1709,28 @@ class Postprocess(TrainModel):
         :param nlead_steps: number of forecast steps
         :return: eval_metric_ds
         """
-        eval_metric_dict = dict([("{0}_{1}_{2}".format(varname, *(fcst_prod, eval_met)), (["init_time", "fcst_hour"],
-                                  np.full((nsamples, nlead_steps), np.nan)))
-                                 for eval_met in eval_metrics for fcst_prod in fcst_products])
+        eval_metric_dict = dict(
+            [
+                (
+                    "{0}_{1}_{2}".format(varname, *(fcst_prod, eval_met)),
+                    (
+                        ["init_time", "fcst_hour"],
+                        np.full((nsamples, nlead_steps), np.nan),
+                    ),
+                )
+                for eval_met in eval_metrics
+                for fcst_prod in fcst_products
+            ]
+        )
 
         init_time_dummy = pd.date_range("1900-01-01 00:00", freq="s", periods=nsamples)
-        eval_metric_ds = xr.Dataset(eval_metric_dict, coords={"init_time": init_time_dummy,  # just a placeholder
-                                                              "fcst_hour": np.arange(1, nlead_steps+1)})
+        eval_metric_ds = xr.Dataset(
+            eval_metric_dict,
+            coords={
+                "init_time": init_time_dummy,  # just a placeholder
+                "fcst_hour": np.arange(1, nlead_steps + 1),
+            },
+        )
 
         return eval_metric_ds
 
@@ -1248,64 +1751,150 @@ class Postprocess(TrainModel):
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--results_dir", type=str, default='results',
-                        help="Directory to save the results")
-    parser.add_argument("--checkpoint", help="Directory with checkpoint or checkpoint name (e.g. ${dir}/model-2000)")
-    parser.add_argument("--mode", type=str, choices=['train', 'val', 'test'], default='test',
-                        help='mode for dataset, val or test.')
-    parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch")
+    parser.add_argument(
+        "--results_dir",
+        type=str,
+        default="results",
+        help="Directory to save the results",
+    )
+    parser.add_argument(
+        "--checkpoint",
+        help="Directory with checkpoint or checkpoint name (e.g. ${dir}/model-2000)",
+    )
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=["train", "val", "test"],
+        default="test",
+        help="mode for dataset, val or test.",
+    )
+    parser.add_argument(
+        "--batch_size", type=int, default=8, help="number of samples in batch"
+    )
     parser.add_argument("--num_stochastic_samples", type=int, default=1)
-    parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use")
+    parser.add_argument(
+        "--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use"
+    )
     parser.add_argument("--seed", type=int, default=7)
-    parser.add_argument("--evaluation_metrics", "-eval_metrics", dest="eval_metrics", nargs="+",
-                        default=("mse", "psnr", "ssim", "acc", "texture"),
-                        help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.")
-    parser.add_argument("--channel", "-channel", dest="channel", type=int, default=0,
-                        help="Channel which is used for evaluation.")
-    parser.add_argument("--lquick_evaluation", "-lquick", dest="lquick", default=False, action="store_true",
-                        help="Flag if (reduced) quick evaluation based on MSE is performed.")
-    parser.add_argument("--evaluation_metric_quick", "-metric_quick", dest="metric_quick", type=str, default="mse",
-                        help="(Only) metric to evaluate when quick evaluation (-lquick) is chosen.")
-    parser.add_argument("--climatology_file", "-clim_fl", dest="clim_fl", type=str, default=False,
-                        help="The path to the climatology_t2m_1991-2020.nc file ")
-    parser.add_argument("--frac_data", "-f_dt",  dest="f_dt", type=float, default=1.,
-                        help="Fraction of dataset to be used for evaluation (only applied when shuffling is active).")
-    parser.add_argument("--test_mode", "-test", dest="test_mode", default=False, action="store_true",
-                        help="Test mode for postprocessing to allow bootstrapping on small datasets.")
+    parser.add_argument(
+        "--evaluation_metrics",
+        "-eval_metrics",
+        dest="eval_metrics",
+        nargs="+",
+        default=("mse", "psnr", "ssim", "acc", "texture"),
+        help="Metrics to be evaluate the trained model. Must be known metrics, see Scores-class.",
+    )
+    parser.add_argument(
+        "--channel",
+        "-channel",
+        dest="channel",
+        type=int,
+        default=0,
+        help="Channel which is used for evaluation.",
+    )
+    parser.add_argument(
+        "--lquick_evaluation",
+        "-lquick",
+        dest="lquick",
+        default=False,
+        action="store_true",
+        help="Flag if (reduced) quick evaluation based on MSE is performed.",
+    )
+    parser.add_argument(
+        "--evaluation_metric_quick",
+        "-metric_quick",
+        dest="metric_quick",
+        type=str,
+        default="mse",
+        help="(Only) metric to evaluate when quick evaluation (-lquick) is chosen.",
+    )
+    parser.add_argument(
+        "--climatology_file",
+        "-clim_fl",
+        dest="clim_fl",
+        type=str,
+        default=False,
+        help="The path to the climatology_t2m_1991-2020.nc file ",
+    )
+    parser.add_argument(
+        "--frac_data",
+        "-f_dt",
+        dest="f_dt",
+        type=float,
+        default=1.0,
+        help="Fraction of dataset to be used for evaluation (only applied when shuffling is active).",
+    )
+    parser.add_argument(
+        "--test_mode",
+        "-test_mode",
+        dest="test_mode",
+        default=False,
+        action="store_true",
+        help="Test mode for postprocessing to allow bootstrapping on small datasets.",
+    )
     args = parser.parse_args()
 
     method = os.path.basename(__file__)
 
-    print('----------------------------------- Options ------------------------------------')
+    print(
+        "----------------------------------- Options ------------------------------------"
+    )
     for k, v in args._get_kwargs():
         print(k, "=", v)
-    print('------------------------------------- End --------------------------------------')
+    print(
+        "------------------------------------- End --------------------------------------"
+    )
 
     eval_metrics = args.eval_metrics
     results_dir = args.results_dir
-    if args.lquick:      # in case of quick evaluation, onyl evaluate MSE and modify results_dir
+    if (
+        args.lquick
+    ):  # in case of quick evaluation, onyl evaluate MSE and modify results_dir
         eval_metrics = [args.metric_quick]
-        if not glob.glob(os.path.join(args.checkpoint,"*.meta")):
-            print(os.path.join(args.checkpoint,"*.meta"))
-            raise ValueError("%{0}: Pass a specific checkpoint-file for quick evaluation.".format(method))
+        if not glob.glob(os.path.join(args.checkpoint, "*.meta")):
+            print(os.path.join(args.checkpoint, "*.meta"))
+            raise ValueError(
+                "%{0}: Pass a specific checkpoint-file for quick evaluation.".format(
+                    method
+                )
+            )
         chp = os.path.basename(args.checkpoint)
         results_dir = args.results_dir + "_{0}".format(chp)
-        print("%{0}: Quick evaluation is chosen. \n * evaluation metric: {1}\n".format(method, args.metric_quick) +
-              "* checkpointed model: {0}\n * no conditional quantile and forecast example plots".format(chp))
+        print(
+            "%{0}: Quick evaluation is chosen. \n * evaluation metric: {1}\n".format(
+                method, args.metric_quick
+            )
+            + "* checkpointed model: {0}\n * no conditional quantile and forecast example plots".format(
+                chp
+            )
+        )
 
     # initialize postprocessing instance
-    postproc_instance = Postprocess(results_dir=results_dir, checkpoint=args.checkpoint, data_mode="test",
-                                    batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples,
-                                    gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args,
-                                    eval_metrics=eval_metrics, channel=args.channel, lquick=args.lquick,
-                                    clim_path=args.clim_fl,frac_data=args.frac_data, ltest=args.test_mode)
+    postproc_instance = Postprocess(
+        results_dir=results_dir,
+        checkpoint=args.checkpoint,
+        data_mode="test",
+        batch_size=args.batch_size,
+        num_stochastic_samples=args.num_stochastic_samples,
+        gpu_mem_frac=args.gpu_mem_frac,
+        seed=args.seed,
+        args=args,
+        eval_metrics=eval_metrics,
+        channel=args.channel,
+        lquick=args.lquick,
+        clim_path=args.clim_fl,
+        frac_data=args.f_dt,
+        ltest=args.test_mode,
+    )
     # run the postprocessing
     postproc_instance.run()
     postproc_instance.handle_eval_metrics()
-    if not args.lquick:    # don't produce additional plots in case of quick evaluation
-        postproc_instance.plot_example_forecasts(metric=args.eval_metrics[0], channel=args.channel)
+    if not args.lquick:  # don't produce additional plots in case of quick evaluation
+        postproc_instance.plot_example_forecasts(
+            metric=args.eval_metrics[0], channel=args.channel
+        )
         postproc_instance.plot_conditional_quantiles()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
index eb69a74045ffad93502afbdb1aac8fa20b593294..f156c133e5b65e364b4d24cb3b610f495a31e668 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
@@ -115,6 +115,7 @@ class ERA5Dataset(object):
         self.tf_names = []
         for year, months in self.data_mode.items():
             for month in months:
+                print("year",year,"month",month)
                 tf_files = "sequence_Y_{}_M_{}_*_to_*.tfrecord*".format(year,month)    
                 self.tf_names.append(tf_files)
         # look for tfrecords in input_dir and input_dir/mode directories
diff --git a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py
index 925f94c44156f86fe43f168e63b3bb5935c2db04..79d7653a5cc55d72abb7ea2fbdb22ca8be4c3b67 100644
--- a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py
+++ b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py
@@ -6,7 +6,6 @@
 """
 
 import tensorflow as tf
-import numpy as np
 weight_decay = 0.0005
 
 def _activation_summary(x):
@@ -67,7 +66,7 @@ def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.co
                                                                  input_channels, num_features],
                                               stddev = 0.01, wd = weight_decay)
         biases = _variable_on_gpu('biases', [num_features], initializer)
-        conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding = 'SAME')
+        conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding='SAME')
         conv_biased = tf.nn.bias_add(conv, biases)
         if activate == "linear":
             return conv_biased
@@ -162,4 +161,3 @@ def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None):
 
 def bn_layers_wrapper(inputs, is_training):
     pass
-   
diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
index a4b60965d6e03a7ccbb4197c3c6c237944a101c5..290def9f7c934f871fedd8c1703bbc114c822dcd 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py
@@ -14,8 +14,7 @@ from .mcnet_model import McNetVideoPredictionModel
 from .test_model import TestModelVideoPredictionModel
 from model_modules.model_architectures import known_models
 from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel
-
-
+from .weatherBench3DCNN import WeatherBenchModel
 
 def get_model_class(model):
     model_mappings = known_models()
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
index 50364b52d2dd13b48bb3087abcc15e147ee1cfd1..bd3795ebecbf750bd115a430ee1adce64074708a 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
@@ -153,6 +153,11 @@ class VanillaConvLstmVideoPredictionModel(object):
 
     @staticmethod
     def convLSTM_cell(inputs, hidden):
+        """
+        SPDX-FileCopyrightText: loliverhennigh 
+        SPDX-License-Identifier: Apache-2.0
+        The following function was revised based on the github https://github.com/loliverhennigh/Convolutional-LSTM-in-Tensorflow 
+        """
         y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset
         channels = inputs.get_shape()[-1]
         # conv lstm cell
diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c3563bc9fbe4709f4f87f4d5d6905269c949ba8
--- /dev/null
+++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
+#
+# SPDX-License-Identifier: MIT
+# Weather Bench models
+__email__ = "b.gong@fz-juelich.de"
+__author__ = "Bing Gong"
+__date__ = "2021-04-13"
+
+import tensorflow as tf
+from tensorflow.contrib.training import HParams
+from model_modules.video_prediction.layers import layer_def as ld
+from model_modules.video_prediction.losses import  *
+
+class WeatherBenchModel(object):
+
+    def __init__(self, hparams_dict=None, mode="train",**kwargs):
+        """
+        This is class for building weahterBench architecture by using updated hparameters
+        args:
+             mode        :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model
+             hparams_dict: dict, the dictionary contains the hparaemters names and values
+        """
+        self.hparams_dict = hparams_dict
+        self.mode = mode
+        self.hparams = self.parse_hparams()
+        self.learning_rate = self.hparams.lr
+        self.filters = self.hparams.filters
+        self.kernels = self.hparams.kernels
+        self.max_epochs = self.hparams.max_epochs
+        self.batch_size = self.hparams.batch_size
+        self.outputs = {}
+        self.total_loss = None
+
+    def get_default_hparams(self):
+        return HParams(**self.get_default_hparams_dict())
+
+    def parse_hparams(self):
+        """
+        Parse the hparams setting to ovoerride the default ones
+        """
+
+        parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
+        return parsed_hparams
+
+
+    def get_default_hparams_dict(self):
+        """
+        The function that contains default hparams
+        Returns:
+            A dict with the following hyperparameters.
+            context_frames  : the number of ground-truth frames to pass in at start.
+            max_epochs      : the number of epochs to train model
+            lr              : learning rate
+            loss_fun        : the loss function
+            filters         : list contains the filters of each convolutional layer
+            kernels         : list contains the kernels size for each convolutional layer
+            """
+        hparams = dict(
+            sequence_length =13,
+            context_frames =1,
+            max_epochs = 20,
+            batch_size = 40,
+            lr = 0.001,
+            shuffle_on_val= True,
+            filters = [64, 64, 64, 64, 3],
+            kernels = [5, 5, 5, 5, 5]
+        )
+        return hparams
+
+
+    def build_graph(self, x):
+        self.is_build_graph = False
+        self.x = x["images"]
+
+        self.global_step = tf.train.get_or_create_global_step()
+        original_global_variables = tf.global_variables()
+
+        # Architecture
+        x_hat = self.build_model(self.x[:,0,:, :, :],self.filters, self.kernels)
+        # Loss
+        
+        self.total_loss = l1_loss(self.x[:,1,:, :,:], x_hat[:,:,:,:])
+
+        # Optimizer
+        self.train_op = tf.train.AdamOptimizer(
+            learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step)
+
+        # outputs
+        self.outputs["total_loss"] = self.total_loss
+       
+        # inferences
+        if self.mode == "test":
+            self.outputs["gen_images"] = self.forecast(self.x, 12, self.filters, self.kernels)
+        else:
+            self.outputs["gen_images"] = x_hat
+
+        # Summary op
+        tf.summary.scalar("total_loss", self.total_loss)
+        self.summary_op = tf.summary.merge_all()
+        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        self.saveable_variables = [self.global_step] + global_variables
+        self.is_build_graph = True
+        return self.is_build_graph
+
+
+    def build_model(self, x, filters, kernels):
+        """Fully convolutional network"""
+        idx = 0 
+        for f, k in zip(filters[:-1], kernels[:-1]):
+            with tf.variable_scope("conv_layer_"+str(idx),reuse=tf.AUTO_REUSE):
+                x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu")
+            idx += 1
+        with tf.variable_scope("Conv_last_layer",reuse=tf.AUTO_REUSE):
+            output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
+        return output
+
+
+    def forecast(self, x, forecast_time, filters, kernels):
+        x_hat = []
+
+        for i in range(forecast_time):
+            if i == 0:
+                x_pred = self.build_model(x[:,i,:, :,:],filters,kernels)
+            else:
+                x_pred = self.build_model(x_pred,filters,kernels)
+            x_hat.append(x_pred)
+
+        x_hat = tf.stack(x_hat)
+        x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])
+        return x_hat