diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a0299e121bebf0b7df05492be71ad0a877c0703
--- /dev/null
+++ b/README.md
@@ -0,0 +1,8 @@
+# Benchmark dataset for statistical downscaling with deep neural networks
+
+This repository aims to provide a benchmark dataset for the statistical downscaling of meteorological fields with deep neural networks.
+The work is pursued within the scope of the [MAELSTROM project](https://www.maelstrom-eurohpc.eu/), which aims to develop new machine learning applications for weather and climate under the coordination of ECMWF. <br>
+The benchmark dataset is based on two reanalysis datasets that are well established and quality-controlled in meteorology: the ERA5 reanalysis serves as the input dataset, whereas the COSMO-REA6 dataset provides the target data for the downscaling. With a grid spacing of 0.275° for the input data compared to 0.0625° for the target data, the benchmark dataset comprises a downscaling factor of about 4. Furthermore, different downscaling tasks adapted from the literature are provided. These pertain to the following meteorological quantities:
+- 2m temperature
+- 10m wind
+- solar irradiance
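+
+A minimal sketch for inspecting one of the benchmark netCDF files with `xarray` (the file name is a placeholder; by convention, predictor variables carry the suffix `_in` and target variables the suffix `_tar`):
+
+```python
+import xarray as xr
+
+# open a benchmark file (replace the placeholder with an actual file of the dataset)
+ds = xr.open_dataset("benchmark_sample.nc")
+
+predictors = [v for v in ds.data_vars if str(v).endswith("_in")]    # ERA5-based input fields
+predictands = [v for v in ds.data_vars if str(v).endswith("_tar")]  # COSMO-REA6 target fields
+print(predictors, predictands, sep="\n")
+```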
diff --git a/downscaling_ap5/config/config_ds_tier2.json b/downscaling_ap5/config/config_ds_tier2.json
index e47fc5f6efb8d97d5a9b72ca54d930b92312373c..6c63b7da6e0045c59410740606c60c05ec75d997 100644
--- a/downscaling_ap5/config/config_ds_tier2.json
+++ b/downscaling_ap5/config/config_ds_tier2.json
@@ -1,4 +1,5 @@
 {
+  "num_files": 33,
   "norm_dims": ["time", "rlat", "rlon"],
   "batch_size": 32,
   "var_tar2in": "hsurf_tar",
diff --git a/downscaling_ap5/config/config_ds_tier2_wind.json b/downscaling_ap5/config/config_ds_tier2_wind.json
index 9591a95ef1502d85659e42f285a8ff5b2bce4ef7..5130555bd702ebeaaa206662c4455fd289f6bd27 100644
--- a/downscaling_ap5/config/config_ds_tier2_wind.json
+++ b/downscaling_ap5/config/config_ds_tier2_wind.json
@@ -2,5 +2,6 @@
   "norm_dims": ["time", "rlat", "rlon"],
   "batch_size": 32,
   "var_tar2in": "hsurf_tar",
-  "predictands": ["u_10m_tar", "v_10m_tar", "hsurf_tar"]
+  "predictands": ["u_10m_tar", "v_10m_tar", "hsurf_tar"],
+  "recon_loss": "mae_vec"
 }
diff --git a/downscaling_ap5/config/config_wgan.json b/downscaling_ap5/config/config_wgan.json
index c9492c9eedd662e1023e4869ab9c1acf1d56a3c9..ab9a0958798f823a51dd7ec37d0e1582d9680285 100644
--- a/downscaling_ap5/config/config_wgan.json
+++ b/downscaling_ap5/config/config_wgan.json
@@ -1,11 +1,8 @@
 {
   "nepochs": 60,
   "d_steps": 5,
-  "z_branch": true,
-  "lr_gen": 5e-05,
-  "lr_critic": 1e-06,
-  "lr_gen_end": 5e-06,
   "lr_decay": true,
-  "named_targets": false
-
+  "named_targets": false,
+  "hparams_generator": {"lr": 5e-05, "lr_end": 5e-06, "z_branch": true, "activation": "swish", "l_avgpool": true},
+  "hparams_critic": {"lr": 1e-06, "activation": "swish"}
 }
diff --git a/downscaling_ap5/config/config_wgan_test.json b/downscaling_ap5/config/config_wgan_test.json
index ec40059eb701b6a8076702b6b29d5cbe15622c6e..9c195b852f183159c71c51323ef0f46d63c7b504 100644
--- a/downscaling_ap5/config/config_wgan_test.json
+++ b/downscaling_ap5/config/config_wgan_test.json
@@ -1,11 +1,8 @@
 {
   "nepochs": 5,
   "d_steps": 5,
-  "z_branch": true,
-  "lr_gen": 5e-05,
-  "lr_critic": 1e-06,
-  "lr_gen_end": 5e-06,
   "lr_decay": true,
-  "named_targets": false
-
+  "named_targets": false,
+  "hparams_generator": {"lr": 5e-05, "lr_end": 5e-06, "z_branch": true, "activation": "swish", "l_avgpool": true},
+  "hparams_critic": {"lr": 1e-06, "activation": "swish"}
 }
diff --git a/downscaling_ap5/env_setup/modules_jsc.sh b/downscaling_ap5/env_setup/modules_jsc.sh
index 15cbca28e674b16c87d5cc6c397ee8c282ce9a0b..88b43745487082603d8b53453a97e411b23318f5 100755
--- a/downscaling_ap5/env_setup/modules_jsc.sh
+++ b/downscaling_ap5/env_setup/modules_jsc.sh
@@ -20,9 +20,11 @@ if [[ 0 == 0 ]]; then  # Restoring from model collection currently throws MPI-se
   ml GCC/11.2.0
   ml ParaStationMPI/5.5.0-1
   ml mpi4py/3.1.3
+  ml tqdm/4.62.3
   ml CDO/2.0.2
   ml NCO/5.0.3
   ml netcdf4-python/1.5.7-serial
+  ml scikit-image/0.18.3
   ml SciPy-bundle/2021.10
   ml xarray/0.20.1
   ml dask/2021.9.1
diff --git a/downscaling_ap5/grid_des/crea6_reg_grid b/downscaling_ap5/grid_des/crea6_reg_grid
new file mode 100644
index 0000000000000000000000000000000000000000..2c1dbc9a71d9727a7746cdd3df80fc8aad35e082
--- /dev/null
+++ b/downscaling_ap5/grid_des/crea6_reg_grid
@@ -0,0 +1,18 @@
+#
+# Regular CREA6 grid for downscaling task
+#
+gridtype  = lonlat
+gridsize  = 93312
+xsize     = 432
+ysize     = 216
+xname     = lon
+xlongname = "longitude"
+xunits    = "degrees"
+yname     = lat
+ylongname = "latitude"
+yunits    = "degrees"
+xfirst    = -1.25
+xinc      = 0.0625
+yfirst    = 42.3125
+yinc      = 0.0625
+
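+# This grid description can be used with CDO to remap data onto the regular CREA6 grid,
+# e.g. (file names are placeholders):
+#   cdo remapbil,crea6_reg_grid input.nc output_on_crea6_grid.nc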
diff --git a/downscaling_ap5/handle_data/handle_data_class.py b/downscaling_ap5/handle_data/handle_data_class.py
index 4849de9bf506ca5e3b46d9e892d1ef38b56d765e..03b5030d5e160659330de129a2af026026619841 100644
--- a/downscaling_ap5/handle_data/handle_data_class.py
+++ b/downscaling_ap5/handle_data/handle_data_class.py
@@ -5,7 +5,7 @@
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-01-20"
-__update__ = "2023-04-17"
+__update__ = "2023-08-18"
 
 import os, glob
 from typing import List
@@ -27,7 +27,7 @@ try:
 except:
     from multiprocessing.pool import ThreadPool
 from all_normalizations import ZScore
-from other_utils import to_list, find_closest_divisor, free_mem
+from other_utils import to_list, find_closest_divisor
 
 
 class HandleDataClass(object):
@@ -321,7 +321,10 @@ class HandleDataClass(object):
             data_iter = data_iter.repeat()
 
         # clean-up to free some memory
+        # free_mem([da, da_in, da_tar, varnames_tar])
         del da
+        del da_in
+        del da_tar
         gc.collect()
 
         return data_iter
@@ -383,6 +386,11 @@ def get_dataset_filename(datadir: str, dataset_name: str, subset: str, laugmente
         if subset == "train":
             fname_suffix = f"{fname_suffix}*"
         if laugmented: raise ValueError("No augmented dataset available for Tier-2.")
+    elif dataset_name == "atmorep":
+        fname_suffix = f"{fname_suffix}_{dataset_name}_{subset}"
+        if subset == "train":
+            fname_suffix = f"{fname_suffix}*"
+        if laugmented: raise ValueError("No augmented dataset available for AtmoRep.")
     else:
         raise ValueError(f"Unknown dataset '{dataset_name}' passed.")
 
@@ -431,7 +439,7 @@ class StreamMonthlyNetCDF(object):
         self.predictor_list = selected_predictors
         self.predictand_list = selected_predictands
         self.n_predictands, self.n_predictors = len(self.predictand_list), len(self.predictor_list)
-        self.all_vars = self.predictor_list + self.predictand_list
+        self.all_vars = self.predictor_list + self.predictand_list       # ordering important to ensure that predictors come first!
         self.ds_all = xr.open_mfdataset(list(self.file_list), decode_cf=False, cache=False)  # , parallel=True)
         self.var_tar2in = var_tar2in
         if self.var_tar2in is not None:
@@ -463,7 +471,7 @@ class StreamMonthlyNetCDF(object):
         return self.nsamples
 
     def getitems(self, indices):
-        da_now = self.data_now.isel({self.sample_dim: indices}).to_array("variables")
+        da_now = self.data_now.isel({self.sample_dim: indices}).to_array("variables").sel({"variables": self.all_vars})
         if self.var_tar2in is not None:
             # NOTE: * The order of the following operation must be the same as in make_tf_dataset_allmem
             #       * The following operation order must concatenate var_tar2in by da_in to ensure
@@ -600,7 +608,7 @@ class StreamMonthlyNetCDF(object):
             if all(stat_list):
                 selected_vars = var_list
             else:
-                miss_inds = [i for i, x in enumerate(stat_list) if x]
+                miss_inds = [i for i, x in enumerate(stat_list) if not x]
                 miss_vars = [var_list[i] for i in miss_inds]
                 raise ValueError(f"Could not find the following variables in the dataset: {*miss_vars,}")
 
@@ -660,8 +668,6 @@ class StreamMonthlyNetCDF(object):
             data_now = xr.concat([data_now, ds_add], dim=self.sample_dim)
             print(f"Appending data with {add_samples:d} samples took {timer() - t1:.2f}s" +
                   f"(total #samples: {data_now.dims[self.sample_dim]})")
-            # free memory
-            free_mem([ds_add, add_samples, istart])
 
         self.data_loaded[il] = data_now
         # timing
@@ -670,8 +676,6 @@ class StreamMonthlyNetCDF(object):
         self.ds_proc_size += data_now.nbytes
         print(f"Dataset #{set_ind:d} ({il+1:d}/2) reading time: {t_read:.2f}s.")
         self.iload_next = il + 1
-        # free memory
-        free_mem([nsamples, t_read, data_now])
 
         return il
 
diff --git a/downscaling_ap5/main_scripts/main_download_era5.py b/downscaling_ap5/main_scripts/main_download_era5.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffb04ac9d73ebe11ffbb96319f6483d3f90a7db3
--- /dev/null
+++ b/downscaling_ap5/main_scripts/main_download_era5.py
@@ -0,0 +1,75 @@
+# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
+#
+# SPDX-License-Identifier: MIT
+
+"""
+Script to download ERA5 data from the CDS API.
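+
+Example call (paths and file names are placeholders):
+    python main_download_era5.py -data_dir <output_dir> -data_req_file <era5_request.json> \
+        -start 1995 -end 2019 -nw 4 -format netcdf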
+"""
+
+__author__ = "Michael Langguth"
+__email__ = "m.langguth@fz-juelich.de"
+__date__ = "2023-11-22"
+__update__ = "2023-08-22"
+
+# import modules
+import os, sys
+import json as js
+import logging
+import argparse
+from download_era5_data import ERA5_Data_Loader
+
+# get logger
+logger = logging.getLogger(os.path.basename(__file__).rstrip(".py"))
+logger.setLevel(logging.INFO)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s')
+
+def main(parser_args):
+
+    data_dir = parser_args.data_dir
+    
+    # read the data request file for the CDS API
+    with parser_args.data_req_file as fdreq:
+        req_dict = js.load(fdreq)
+
+
+    # create output directory
+    if not os.path.exists(data_dir):
+        os.makedirs(data_dir)
+
+    # create logger handlers
+    logfile = os.path.join(data_dir, f"download_era5_{parser_args.exp_name}.log")
+    if os.path.isfile(logfile): os.remove(logfile)
+    fh = logging.FileHandler(logfile)
+    ch = logging.StreamHandler(sys.stdout)
+    ch.setLevel(logging.DEBUG)
+    fh.setLevel(logging.INFO)
+
+    fh.setFormatter(formatter)
+    ch.setFormatter(formatter)
+
+    logger.addHandler(fh), logger.addHandler(ch)
+
+    # create data loader instance
+    data_loader = ERA5_Data_Loader(parser_args.nworkers)
+
+    # download data
+    _ = data_loader(req_dict, data_dir, parser_args.start, parser_args.end, parser_args.format)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data_directory", "-data_dir", dest="data_dir", type=str, required=True,
+                        help="Directory where test dataset (netCDF-file) is stored.")
+    parser.add_argument("--data_request_file", "-data_req_file", dest="data_req_file", type=argparse.FileType("r"), required=True,
+                        help="File containing data request information for the CDS API.")
+    parser.add_argument("--year_start", "-start", dest="start", type=int, default=1995, 
+                        help="Start year of ERA5-data request.")
+    parser.add_argument("--year_end", "-end", dest="end", type=int, default=2019,
+                        help="End year of ERA5-data request.")
+    parser.add_argument("--nworkers", "-nw", dest="nowrkers", type=int, default=4,
+                        help="Number of workers to download ERA5 data.")
+    parser.add_argument("--format", "-format", dest="format", type=str, default="netcdf",
+                        help="Format of downloaded data.")
+
+    args = parser.parse_args()
+    main(args)
\ No newline at end of file
diff --git a/downscaling_ap5/main_scripts/main_postprocess.py b/downscaling_ap5/main_scripts/main_postprocess.py
index 0e60fbd7fbf19a05c6e0c88cb65bc4b66cceca43..e9735171e6fc7050467e8691d7cb49f39d845fe7 100644
--- a/downscaling_ap5/main_scripts/main_postprocess.py
+++ b/downscaling_ap5/main_scripts/main_postprocess.py
@@ -9,23 +9,28 @@ Driver-script to perform inference on trained downscaling models.
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-12-08"
-__update__ = "2022-12-08"
+__update__ = "2023-08-21"
 
 import os, sys, glob
 import logging
 import argparse
 from timeit import default_timer as timer
 import json as js
+from datetime import datetime as dt
+import gc
 import numpy as np
 import xarray as xr
 import tensorflow.keras as keras
 import matplotlib as mpl
+import cartopy.crs as ccrs
 from handle_data_unet import *
 from handle_data_class import HandleDataClass, get_dataset_filename
 from all_normalizations import ZScore
 from statistical_evaluation import Scores
-from postprocess import get_model_info, run_evaluation_time, run_evaluation_spatial
-from datetime import datetime as dt
+from postprocess import get_model_info, run_evaluation_time, run_evaluation_spatial, run_feature_importance
+from model_utils import convert_to_xarray
+#from other_utils import free_mem
 
 # get logger
 logger = logging.getLogger(os.path.basename(__file__).rstrip(".py"))
@@ -33,6 +38,7 @@ logger.setLevel(logging.INFO)
 formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s')
 
 
 def main(parser_args):
 
     t0 = timer()
@@ -59,7 +65,8 @@ def main(parser_args):
 
     logger.addHandler(fh), logger.addHandler(ch)
     
-    logger.info(f"Start postprocessing at {dt.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    #logger.info(f"Start postprocessing at {dt.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    logger.info(f"Start postprocessing at...")
 
     # read configuration files
     md_config_pattern, ds_config_pattern = f"config_{model_type}.json", f"config_ds_{parser_args.dataset}.json"
@@ -105,41 +112,46 @@ def main(parser_args):
     logger.info(f"Variable {tar_varname} serves as ground truth data.")
 
     with xr.open_dataset(fdata_test) as ds_test:
-        ground_truth = ds_test[tar_varname].astype("float32", copy=False)
+        ground_truth = ds_test[tar_varname].astype("float32", copy=True)
         ds_test = norm.normalize(ds_test)
 
     # prepare training and validation data
     logger.info(f"Start preparing test dataset...")
     t0_preproc = timer()
 
-    da_test = HandleDataClass.reshape_ds(ds_test.astype("float32", copy=False))
-    tfds_test = HandleDataClass.make_tf_dataset_allmem(da_test.astype("float32", copy=True), ds_dict["batch_size"],
-                                                       ds_dict["predictands"], lshuffle=False, var_tar2in=ds_dict["var_tar2in"],
-                                                       named_targets=named_targets, lrepeat=False, drop_remainder=False)
+    da_test = HandleDataClass.reshape_ds(ds_test).astype("float32", copy=True)
+    
+    # clean-up to reduce memory footprint
+    del ds_test
+    gc.collect()
+    #free_mem([ds_test])
+
+    tfds_opts = {"batch_size": ds_dict["batch_size"], "predictands": ds_dict["predictands"], "predictors": ds_dict.get("predictors", None),
+                "lshuffle": False, "var_tar2in": ds_dict["var_tar2in"], "named_targets": named_targets, "lrepeat": False, "drop_remainder": False}    
 
-    # perform normalization
-    da_test_in, da_test_tar = HandleDataClass.split_in_tar(da_test, predictands=ds_dict["predictands"])
+    tfds_test = HandleDataClass.make_tf_dataset_allmem(da_test, **tfds_opts)
+    
+    predictors = ds_dict.get("predictors", None)
+    if predictors is None:
+        predictors = [var for var in list(da_test["variables"].values) if var.endswith("_in")]
+        if ds_dict.get("var_tar2in", False): predictors.append(ds_dict["var_tar2in"])
 
     # start inference
     logger.info(f"Preparation of test dataset finished after {timer() - t0_preproc:.2f}s. " +
                  "Start inference on trained model...")
     t0_train = timer()
-    y_pred_trans = trained_model.predict(tfds_test, verbose=2)
+    y_pred = trained_model.predict(tfds_test, verbose=2)
 
     logger.info(f"Inference on test dataset finished. Start denormalization of output data...")
-    # get coordinates and dimensions from target data
-    slice_dict = {"variables": 0} if hparams_dict["z_branch"] else {}
-    coords = da_test_tar.isel(slice_dict).squeeze().coords
-    dims = da_test_tar.isel(slice_dict).squeeze().dims
-    if hparams_dict["z_branch"]:
-        # slice data to get first channel only
-        if isinstance(y_pred_trans, list): y_pred_trans = y_pred_trans[0]
-        y_pred = xr.DataArray(y_pred_trans[..., 0].squeeze(), coords=coords, dims=dims)
-    else:
-        # no slicing required
-        y_pred = xr.DataArray(y_pred_trans.squeeze(), coords=coords, dims=dims)
-    # perform denormalization
-    y_pred = norm.denormalize(y_pred.squeeze(), varname=tar_varname)
+    
+    # clean-up to reduce memory footprint
+    del tfds_test
+    gc.collect()
+    #free_mem([tfds_test])
+
+    # convert to xarray
+    y_pred = convert_to_xarray(y_pred, norm, tar_varname, da_test.sel({"variables": tar_varname}).squeeze().coords,
+                               da_test.sel({"variables": tar_varname}).squeeze().dims, hparams_dict["z_branch"])
 
     # write inference data to netCDf
     logger.info(f"Write inference data to netCDF-file '{ncfile_out}'")
@@ -155,7 +167,7 @@ def main(parser_args):
 
     logger.info("Start temporal evaluation...")
     t0_tplot = timer()
-    _ = run_evaluation_time(score_engine, "rmse", "K", plt_dir, value_range=(0., 3.), model_type=model_type)
+    rmse_all = run_evaluation_time(score_engine, "rmse", "K", plt_dir, value_range=(0., 3.), model_type=model_type)
     _ = run_evaluation_time(score_engine, "bias", "K", plt_dir, value_range=(-1., 1.), ref_line=0.,
                             model_type=model_type)
     _ = run_evaluation_time(score_engine, "grad_amplitude", "1", plt_dir, value_range=(0.7, 1.1),
@@ -163,19 +175,43 @@ def main(parser_args):
 
     logger.info(f"Temporal evalutaion finished in {timer() - t0_tplot:.2f}s.")
 
+    # run feature importance analysis for RMSE
+    logger.info("Start feature importance analysis...")
+    t0_fi = timer()
+
+    rmse_ref = rmse_all.mean().values
+
+    _ = run_feature_importance(da_test, predictors, tar_varname, trained_model, norm, "rmse", rmse_ref,
+                               tfds_opts, plt_dir, patch_size=(6, 6), variable_dim="variables")
+    
+    logger.info(f"Feature importance analysis finished in {timer() - t0_fi:.2f}s.")
+    
+    # clean-up to reduce memory footprint
+    del da_test
+    gc.collect()
+    #free_mem([da_test])
+
     # instantiate score engine with retained spatial dimensions
     score_engine = Scores(y_pred, ground_truth, [])
 
+    # ad-hoc adaptation of the map projection based on norm_dims
+    if "rlat" in ds_dict["norm_dims"]:
+        proj = ccrs.RotatedPole(pole_longitude=-162.0, pole_latitude=39.25)
+    else:
+        proj = ccrs.PlateCarree()
+
     logger.info("Start spatial evaluation...")
     lvl_rmse = np.arange(0., 3.1, 0.2)
     cmap_rmse = mpl.cm.afmhot_r(np.linspace(0., 1., len(lvl_rmse)))
-    _ = run_evaluation_spatial(score_engine, "rmse", os.path.join(plt_dir, "rmse_spatial"), cmap=cmap_rmse,
-                               levels=lvl_rmse)
+    _ = run_evaluation_spatial(score_engine, "rmse", os.path.join(plt_dir, "rmse_spatial"), 
+                               dims=ds_dict["norm_dims"][1::], cmap=cmap_rmse, levels=lvl_rmse,
+                               projection=proj)
 
     lvl_bias = np.arange(-2., 2.1, 0.1)
     cmap_bias = mpl.cm.seismic(np.linspace(0., 1., len(lvl_bias)))
-    _ = run_evaluation_spatial(score_engine, "bias", os.path.join(plt_dir, "bias_spatial"), cmap=cmap_bias,
-                               levels=lvl_bias)
+    _ = run_evaluation_spatial(score_engine, "bias", os.path.join(plt_dir, "bias_spatial"), 
+                               dims=ds_dict["norm_dims"][1::], cmap=cmap_bias, levels=lvl_bias,
+                               projection=proj)
 
     logger.info(f"Spatial evalutaion finished in {timer() - t0_tplot:.2f}s.")
 
diff --git a/downscaling_ap5/main_scripts/main_train.py b/downscaling_ap5/main_scripts/main_train.py
index 66970190c0e00b826422ada7c7ae0ec09ff5b4a4..04521802e9f79dc6b39da66ae46bd6e87d79a7f8 100644
--- a/downscaling_ap5/main_scripts/main_train.py
+++ b/downscaling_ap5/main_scripts/main_train.py
@@ -9,7 +9,7 @@ Driver-script to train downscaling models.
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-10-06"
-__update__ = "2023-04-17"
+__update__ = "2023-08-18"
 
 import os
 import argparse
@@ -24,16 +24,13 @@ from tensorflow.keras.utils import plot_model
 from all_normalizations import ZScore
 from model_utils import ModelEngine, TimeHistory, handle_opt_utils, get_loss_from_history
 from handle_data_class import HandleDataClass, get_dataset_filename
-from other_utils import free_mem, print_gpu_usage, print_cpu_usage, copy_filelist
-from benchmark_utils import BenchmarkCSV, get_training_time_dict
+from other_utils import print_gpu_usage, print_cpu_usage, copy_filelist
+from benchmark_utils import get_training_time_dict
 
 
 # Open issues:
 # * d_steps must be parsed with hparams_dict as model is uninstantiated at this point and thus no default parameters
 #   are available
-# * flag named_targets must be set to False in hparams_dict for WGAN to work with U-Net 
-# * ensure that dataset defaults are set (such as predictands for WGAN)
-# * replacement of manual benchmarking by JUBE-benchmarking to avoid duplications
 
 def main(parser_args):
     # start timing
@@ -67,9 +64,6 @@ def main(parser_args):
         data_norm, write_norm = None, True
         norm_dims = ds_dict["norm_dims"]
 
-    # initialize benchmarking object
-    bm_obj = BenchmarkCSV(os.path.join(os.getcwd(), f"benchmark_training_{parser_args.model}.csv"))
-
     # get model instance and set-up batch size
     model_instance = ModelEngine(parser_args.model)
     # Note: bs_train is introduced to allow substepping in the training loop, e.g. for WGAN where n optimization steps
@@ -90,7 +84,7 @@ def main(parser_args):
     # the dataset
     if "*" in fname_or_patt_train:
         ds_obj, tfds_train = HandleDataClass.make_tf_dataset_dyn(datadir, fname_or_patt_train, bs_train, nepochs,
-                                                                 30, ds_dict["predictands"],
+                                                                 ds_dict["num_files"], ds_dict["predictands"],
                                                                  predictors=ds_dict.get("predictors", None),
                                                                  var_tar2in=ds_dict["var_tar2in"],
                                                                  named_targets=named_targets,
@@ -100,7 +94,11 @@ def main(parser_args):
         tfds_train_size = ds_obj.dataset_size
     else:
         ds_train = xr.open_dataset(fname_or_patt_train)
-        da_train = HandleDataClass.reshape_ds(ds_train.astype("float32", copy=False))
+        da_train = HandleDataClass.reshape_ds(ds_train).astype("float32", copy=True)
+
+        # free up some memory
+        del ds_train
+        gc.collect()
 
         if not data_norm:
             # data_norm must be freshly instantiated (triggering later parameter retrieval)
@@ -111,9 +109,14 @@ def main(parser_args):
                                                             predictors=ds_dict.get("predictors", None),
                                                             var_tar2in=ds_dict["var_tar2in"],
                                                             named_targets=named_targets)
+        
         nsamples, shape_in = da_train.shape[0], tfds_train.element_spec[0].shape[1:].as_list()
         tfds_train_size = da_train.nbytes
 
+        # clean up to save some memory
+        del da_train
+        gc.collect()
+
     if write_norm:
         data_norm.save_norm_to_file(os.path.join(model_savedir, "norm.json"))
 
@@ -125,15 +128,17 @@ def main(parser_args):
     fdata_val = get_dataset_filename(datadir, dataset, "val", ds_dict.get("laugmented", False))
     with xr.open_dataset(fdata_val) as ds_val:
         ds_val = data_norm.normalize(ds_val)
-    da_val = HandleDataClass.reshape_ds(ds_val)
+    da_val = HandleDataClass.reshape_ds(ds_val).astype("float32", copy=True)
 
-    tfds_val = HandleDataClass.make_tf_dataset_allmem(da_val.astype("float32", copy=True), ds_dict["batch_size"],
+    tfds_val = HandleDataClass.make_tf_dataset_allmem(da_val, ds_dict["batch_size"],
                                                       ds_dict["predictands"], predictors=ds_dict.get("predictors", None),
                                                       lshuffle=True, var_tar2in=ds_dict["var_tar2in"],
                                                       named_targets=named_targets)
     
     # clean up to save some memory
-    free_mem([ds_val, da_val])
+    del ds_val
+    del da_val
+    gc.collect()
 
     tval_load = timer() - t0_val
     print(f"Validation data preparation time: {tval_load:.2f}s.")
@@ -177,21 +182,21 @@ def main(parser_args):
         ttrain_load = sum(ds_obj.reading_times) + tval_load
         print(f"Data loading time: {ttrain_load:.2f}s.")
         print(f"Average throughput: {ds_obj.ds_proc_size / 1.e+06 / training_times['Total training time']:.3f} MB/s")
-    benchmark_dict = {**{"data loading time": ttrain_load}, **training_times}
-    print(f"Model '{parser_args.exp_name}' training time: {training_times['Total training time']:.2f} s. " +
-          f"Save model to '{model_savedir}'")
 
     # save trained model
     t0_save = timer()
 
+    model_savedir_last = os.path.join(model_savedir, f"{parser_args.exp_name}_last")
     os.makedirs(model_savedir, exist_ok=True)
-    model.save(filepath=model_savedir)
+    model.save(filepath=model_savedir_last)
 
-    if callable(getattr(model, "plot_model", False)):
-        model.plot_model(model_savedir, show_shapes=True)
-    else:
+    try:
+        model.plot_model(model_savedir, show_shapes=True)  # , show_layer_activations=True
+    except (AttributeError, TypeError):
         plot_model(model, os.path.join(model_savedir, f"plot_{parser_args.exp_name}.png"),
-                   show_shapes=True)
+                   show_shapes=True) #, show_layer_activations=True)
 
     # final timing
     tend = timer()
@@ -203,36 +208,6 @@ def main(parser_args):
     print_gpu_usage("Final GPU memory: ")
     print_cpu_usage("Final CPU memory: ")
 
-    # populate benchmark dictionary
-    benchmark_dict.update({"saving model time": saving_time, "total runtime": tot_run_time})
-    benchmark_dict.update({"job id": job_id, "#nodes": 1, "#cpus": len(os.sched_getaffinity(0)), "#gpus": 1,
-                           "#mpi tasks": 1, "node id": None, "max. gpu power": None, "gpu energy consumption": None})
-    try:
-        benchmark_dict["final training loss"] = get_loss_from_history(history, "loss")
-    except KeyError:
-        benchmark_dict["final training loss"] = get_loss_from_history(history, "recon_loss")
-    try:
-        benchmark_dict["final validation loss"] = get_loss_from_history(history, "val_loss")
-    except KeyError:
-        benchmark_dict["final validation loss"] = get_loss_from_history(history, "val_recon_loss")
-    # ... and save CSV-file with tracked data on disk
-    bm_obj.populate_csv_from_dict(benchmark_dict)
-
-    js_file = os.path.join(model_savedir, "benchmark_training_static.json")
-    if not os.path.isfile(js_file):
-        func_model_info = getattr(model, "get_model_info", None)
-        if callable(func_model_info):
-            model_info = func_model_info()
-        else:
-            model_info = {}
-        stat_info = {"static_model_info": model_info,
-                     "data_info": {"training data size": tfds_train_size, "validation data size": da_val.nbytes,
-                                   "nsamples": nsamples, "shape_samples": shape_in,
-                                   "batch_size": ds_dict["batch_size"]}}
-
-        with open(js_file, "w") as jsf:
-            js.dump(stat_info, jsf)
-
     print("Finished job at {0}".format(dt.strftime(dt.now(), "%Y-%m-%d %H:%M:%S")))
 
 
diff --git a/downscaling_ap5/models/custom_losses.py b/downscaling_ap5/models/custom_losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..46e5e537a2a40d4d61657c90da2581fe651d6f4b
--- /dev/null
+++ b/downscaling_ap5/models/custom_losses.py
@@ -0,0 +1,160 @@
+# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
+#
+# SPDX-License-Identifier: MIT
+
+"""
+Some customized losses (e.g. on vector quantities)
+"""
+
+__author__ = "Michael Langguth"
+__email__ = "m.langguth@fz-juelich.de"
+__date__ = "2023-06-16"
+__update__ = "2023-06-16"
+
+# import module
+import inspect
+import tensorflow as tf
+
+def fix_channels(n_channels, **fixed_kwargs):
+    """
+    Decorator to fix the number of channels (and optional further keyword arguments such as
+    nd_vec) of the channel-aware loss functions below, so that the resulting function only
+    expects (y_true, y_pred) as required by Keras.
+    """
+    def decorator(func):
+        def wrapper(y_true, y_pred, **func_kwargs):
+            return func(y_true, y_pred, n_channels, **fixed_kwargs, **func_kwargs)
+        return wrapper
+    return decorator
+
+def get_custom_loss(loss_name, **kwargs):
+    """
+    Loss factory including some customized losses and all available Keras losses
+    :param loss_name: name of the loss function
+    :return: the respective loss function
+    """
+    known_losses = ["mse_channels", "mae_channels", "mae_vec", "mse_vec", "critic", "critic_generator"] + \
+                   [loss_cls[0] for loss_cls in inspect.getmembers(tf.keras.losses, inspect.isclass)]
+
+    loss_name = loss_name.lower()
+
+    n_channels = kwargs.get("n_channels", None)
+
+    if loss_name == "mse_channels":
+        assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}."
+        loss_fn = fix_channels(**kwargs)(mse_channels)
+    elif loss_name == "mae_channels":
+        assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}."
+        loss_fn = fix_channels(**kwargs)(mae_channels)
+    elif loss_name == "mae_vec":
+        assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}."
+        loss_fn = fix_channels(**kwargs)(mae_vec)
+    elif loss_name == "mse_vec":
+        assert n_channels > 0, f"n_channels must be a number larger than zero, but is {n_channels}."
+        loss_fn = fix_channels(**kwargs)(mse_vec)
+    elif loss_name == "critic":
+        loss_fn = critic_loss
+    elif loss_name == "critic_generator":
+        loss_fn = critic_gen_loss
+    else:
+        try:
+            loss_fn = getattr(tf.keras.losses, loss_name)(**kwargs)
+        except AttributeError:
+            raise ValueError(f"{loss_name} is not a valid loss function. Choose one of the following: {known_losses}")
+
+    return loss_fn
+
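+# Illustrative usage (example values): a vector-valued MAE over the first two channels of a
+# three-channel output, e.g. for the (u, v)-wind components plus one further predictand:
+#   recon_loss = get_custom_loss("mae_vec", n_channels=3, nd_vec=2)
+#   loss_value = recon_loss(y_true, y_pred)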
+
+def mae_channels(x, x_hat, n_channels: int = None, channels_last: bool = True, avg_channels: bool = False):
+    rloss = 0.
+    if channels_last:
+        # get MAE for all output heads
+        for i in range(n_channels):
+            rloss += tf.reduce_mean(tf.abs(tf.squeeze(x_hat[..., i]) - x[..., i]))
+    else:
+        for i in range(n_channels):
+            rloss += tf.reduce_mean(tf.abs(tf.squeeze(x_hat[:, i, ...]) - x[:, i, ...]))
+            
+    if avg_channels:
+        rloss /= n_channels
+        
+    return rloss
+            
+def mse_channels(x, x_hat, n_channels, channels_last: bool = True, avg_channels: bool = False):
+    rloss = 0.
+    if channels_last:
+        # get MSE for all output heads
+        for i in range(n_channels):
+            rloss += tf.reduce_mean(tf.square(tf.squeeze(x_hat[..., i]) - x[..., i]))
+    else:
+        for i in range(n_channels):
+            rloss += tf.reduce_mean(tf.square(tf.squeeze(x_hat[:, i, ...]) - x[:, i, ...]))
+            
+    if avg_channels:
+        rloss /= n_channels
+        
+    return rloss
+
+def mae_vec(x, x_hat, n_channels, channels_last: bool = True, avg_channels: bool = False, nd_vec:  int = None):
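+    """
+    MAE-type loss on vector quantities: the mean Euclidean norm of the error over the first
+    nd_vec channels (e.g. wind components); any remaining channels are added as plain MAE.
+    """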
+    
+    if nd_vec is None:
+        nd_vec = n_channels
+        
+    rloss = 0.
+    if channels_last:
+        vec_ind = -1
+        diff = tf.squeeze(x_hat[..., 0:nd_vec]) - x[..., 0:nd_vec]
+    else:
+        vec_ind = 1
+        diff = tf.squeeze(x_hat[:,0:nd_vec, ...]) - x[:,0:nd_vec, ...]
+    
+    rloss = tf.reduce_mean(tf.norm(diff, axis=vec_ind))
+    #rloss = tf.reduce_mean(tf.math.reduce_euclidean_norm(diff, axis=vec_ind))
+
+    if n_channels > nd_vec:
+        # remaining (non-vector) channels contribute via a plain channel-wise MAE
+        if channels_last:
+            rloss += mae_channels(x[..., nd_vec:], x_hat[..., nd_vec:], n_channels - nd_vec,
+                                  channels_last=True, avg_channels=avg_channels)
+        else:
+            rloss += mae_channels(x[:, nd_vec:, ...], x_hat[:, nd_vec:, ...], n_channels - nd_vec,
+                                  channels_last=False, avg_channels=avg_channels)
+            
+    return rloss
+
+def mse_vec(x, x_hat, n_channels, channels_last: bool = True, avg_channels: bool = False, nd_vec:  int = None):
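+    """
+    MSE-type loss on vector quantities: the mean squared Euclidean norm of the error over
+    the first nd_vec channels; any remaining channels are added as plain MSE.
+    """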
+    
+    if nd_vec is None:
+        nd_vec = n_channels
+        
+    rloss = 0.
+    
+    if channels_last:
+        vec_ind = -1
+        diff = tf.squeeze(x_hat[..., 0:nd_vec]) - x[..., 0:nd_vec]
+    else:
+        vec_ind = 1
+        diff = tf.squeeze(x_hat[:,0:nd_vec, ...]) - x[:,0:nd_vec, ...]
+        
+    rloss = tf.reduce_mean(tf.square(tf.norm(diff, axis=vec_ind)))
+    
+    if n_channels > nd_vec:
+        # remaining (non-vector) channels contribute via a plain channel-wise MSE
+        if channels_last:
+            rloss += mse_channels(x[..., nd_vec:], x_hat[..., nd_vec:], n_channels - nd_vec,
+                                  channels_last=True, avg_channels=avg_channels)
+        else:
+            rloss += mse_channels(x[:, nd_vec:, ...], x_hat[:, nd_vec:, ...], n_channels - nd_vec,
+                                  channels_last=False, avg_channels=avg_channels)
+            
+    return rloss
+
+def critic_loss(critic_real, critic_gen):
+    """
+    The critic is optimized to maximize the difference between the generated and the real data max(real - gen).
+    This is equivalent to minimizing the negative of this difference, i.e. min(gen - real) = max(real - gen)
+    :param critic_real: critic on the real data
+    :param critic_gen: critic on the generated data
+    :return c_loss: loss to optimize the critic
+    """
+    c_loss = tf.reduce_mean(critic_gen - critic_real)
+
+    return c_loss
+
+
+def critic_gen_loss(critic_gen):
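+    """
+    The generator is optimized to maximize the critic score on the generated data,
+    which is equivalent to minimizing the negative mean critic score.
+    :param critic_gen: critic on the generated data
+    :return cg_loss: loss to optimize the generator
+    """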
+    cg_loss = -tf.reduce_mean(critic_gen)
+
+    return cg_loss
\ No newline at end of file
diff --git a/downscaling_ap5/models/model_utils.py b/downscaling_ap5/models/model_utils.py
index 5011e9cdd2eaeab3d6bdf965ed80d6c69b86d325..0ee63045c1ed817ff3bd927642ea14b5c897562f 100644
--- a/downscaling_ap5/models/model_utils.py
+++ b/downscaling_ap5/models/model_utils.py
@@ -9,10 +9,11 @@ Some auxiliary methods to create Keras models.
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-05-26"
-__update__ = "2023-03-10"
+__update__ = "2023-10-12"
 
 # import modules
 from timeit import default_timer as timer
+import xarray as xr
 import tensorflow.keras as keras
 from unet_model import sha_unet, UNET
 from wgan_model import WGAN, critic_model
@@ -137,3 +138,28 @@ def handle_opt_utils(model: keras.Model, opt_funcname: str):
     return opt_dict
 
 
+def convert_to_xarray(mout_np, norm, varname, coords, dims, z_branch=False):
+    """
+    Converts numpy-array of model output to xarray.DataArray and performs denormalization.
+    :param mout_np: numpy-array of model output
+    :param norm: normalization object
+    :param varname: name of variable
+    :param coords: coordinates of target data
+    :param dims: dimensions of target data
+    :param z_branch: flag for z-branch
+    :return: xarray.DataArray of model output with denormalized data
+    """
+    if z_branch:
+        # slice data to get first channel only
+        if isinstance(mout_np, list): mout_np = mout_np[0]
+        mout_xr = xr.DataArray(mout_np[..., 0].squeeze(), coords=coords, dims=dims, name=varname)
+    else:
+        # no slicing required
+        mout_xr = xr.DataArray(mout_np.squeeze(), coords=coords, dims=dims, name=varname)
+
+    # perform denormalization
+    mout_xr = norm.denormalize(mout_xr, varname=varname)
+
+    return mout_xr
+
+
diff --git a/downscaling_ap5/models/unet_model.py b/downscaling_ap5/models/unet_model.py
index 60ee0ffdc30929bd822cd352fe5f747e8c002655..0465a89f3fd17357dbf92b6d276c9347fbc4cf76 100644
--- a/downscaling_ap5/models/unet_model.py
+++ b/downscaling_ap5/models/unet_model.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
+# SPDX-FileCopyrightText: 2023 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
 #
 # SPDX-License-Identifier: MIT
 
@@ -9,28 +9,29 @@ Methods to set-up U-net models incl. its building blocks.
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2021-XX-XX"
-__update__ = "2023-05-06"
+__update__ = "2023-11-22"
 
 # import modules
 import os
 from typing import List
+import inspect
 import numpy as np
 import tensorflow as tf
 import tensorflow.keras as keras
 # all the layers used for U-net
 from tensorflow.keras.layers import (Concatenate, Conv2D, Conv2DTranspose, Input, MaxPool2D, BatchNormalization,
-                                     Activation)
+                                     Activation, AveragePooling2D)
 from tensorflow.keras.models import Model
 from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint, EarlyStopping
 
-import advanced_activations
+from advanced_activations import advanced_activation
 from other_utils import to_list
 
 # building blocks for Unet
 
 
 def conv_block(inputs, num_filters: int, kernel: tuple = (3, 3), strides: tuple = (1, 1), padding: str = "same",
-               activation: str = "relu", activation_args=None, kernel_init: str = "he_normal",
+               activation: str = "swish", activation_args={}, kernel_init: str = "he_normal",
                l_batch_normalization: bool = True):
     """
     A convolutional layer with optional batch normalization
@@ -51,101 +52,139 @@ def conv_block(inputs, num_filters: int, kernel: tuple = (3, 3), strides: tuple
     try:
         x = Activation(activation)(x)
     except ValueError:
-        ac_layer = advanced_activations(activation, *activation_args)
+        ac_layer = advanced_activation(activation, *activation_args)
         x = ac_layer(x)
 
     return x
 
 
-def conv_block_n(inputs, num_filters: int, n: int = 2, kernel: tuple = (3, 3), strides: tuple = (1, 1),
-                 padding: str = "same", activation: str = "relu", activation_args=None,
-                 kernel_init: str = "he_normal", l_batch_normalization: bool = True):
+def conv_block_n(inputs, num_filters: int, n: int = 2, **kwargs):
     """
     Sequential application of two convolutional layers (using conv_block).
     :param inputs: the input data with dimensions nx, ny and nc
     :param num_filters: number of filters (output channel dimension)
     :param n: number of convolutional blocks
-    :param kernel: tuple for convolution kernel size
-    :param strides: tuple for stride of convolution
-    :param padding: technique for padding (e.g. "same" or "valid")
-    :param activation: activation fuction for neurons (e.g. "relu")
-    :param activation_args: arguments for activation function given that advanced layers are applied
-    :param kernel_init: initialization technique (e.g. "he_normal" or "glorot_uniform")
-    :param l_batch_normalization: flag if batch normalization should be applied
+    :param kwargs: keyword arguments for conv_block
     """
-    x = conv_block(inputs, num_filters, kernel, strides, padding, activation, activation_args,
-                   kernel_init, l_batch_normalization)
+    x = conv_block(inputs, num_filters, **kwargs)
     for _ in np.arange(n - 1):
-        x = conv_block(x, num_filters, kernel, strides, padding, activation, activation_args,
-                       kernel_init, l_batch_normalization)
+        x = conv_block(x, num_filters, **kwargs)
 
     return x
 
 
-def encoder_block(inputs, num_filters, kernel_maxpool: tuple = (2, 2), l_large: bool = True):
+def encoder_block(inputs, num_filters, l_large: bool = True, kernel_pool: tuple = (2, 2), l_avgpool: bool = False, **kwargs):
     """
     One complete encoder-block used in U-net.
     :param inputs: input to encoder block
     :param num_filters: number of filters/channel to be used in convolutional blocks
-    :param kernel_maxpool: kernel used in max-pooling
     :param l_large: flag for large encoder block (two consecutive convolutional blocks)
+    :param kernel_pool: kernel used in the pooling operation
+    :param l_avgpool: flag if average pooling is used instead of max pooling
+    :param kwargs: keyword arguments for conv_block
     """
     if l_large:
-        x = conv_block_n(inputs, num_filters, n=2)
+        x = conv_block_n(inputs, num_filters, n=2, **kwargs)
     else:
-        x = conv_block(inputs, num_filters)
+        x = conv_block(inputs, num_filters, **kwargs)
 
-    p = MaxPool2D(kernel_maxpool)(x)
+    if l_avgpool:
+        p = AveragePooling2D(kernel_pool)(x)
+    else:
+        p = MaxPool2D(kernel_pool)(x)
 
     return x, p
 
+def subpixel_block(inputs, num_filters, kernel: tuple = (3,3), upscale_fac: int = 2,
+        padding: str = "same", activation: str = "swish", activation_args: dict = {},
+        kernel_init: str = "he_normal"):
 
-def decoder_block(inputs, skip_features, num_filters, kernel: tuple = (3, 3), strides_up: int = 2,
-                  padding: str = "same", activation="relu", kernel_init="he_normal",
-                  l_batch_normalization: bool = True):
+    x = Conv2D(num_filters * (upscale_fac ** 2), kernel, padding=padding, kernel_initializer=kernel_init,
+               activation=None)(inputs)
+    try:
+        x = Activation(activation)(x)
+    except ValueError:
+        ac_layer = advanced_activation(activation, *activation_args)
+        x = ac_layer(x)
+    
+    x = tf.nn.depth_to_space(x, upscale_fac)
+
+    return x
+
+
+
+def decoder_block(inputs, skip_features, num_filters, strides_up: int = 2, l_subpixel: bool = False, **kwargs_conv_block):
     """
     One complete decoder block used in U-net (reverting the encoder)
     """
-    x = Conv2DTranspose(num_filters, (strides_up, strides_up), strides=strides_up, padding="same")(inputs)
+    if l_subpixel:
+        kwargs_subpixel = kwargs_conv_block.copy()
+        for ex_key in ["strides", "l_batch_normalization"]:
+            kwargs_subpixel.pop(ex_key, None)
+        x = subpixel_block(inputs, num_filters, upscale_fac=strides_up, **kwargs_subpixel)
+    else:
+        x = Conv2DTranspose(num_filters, (strides_up, strides_up), strides=strides_up, padding="same")(inputs)
+        
+        activation = kwargs_conv_block.get("activation", "relu")
+        activation_args = kwargs_conv_block.get("activation_args", {})
+
+        try:
+            x = Activation(activation)(x)
+        except ValueError:
+            ac_layer = advanced_activation(activation, *activation_args)
+            x = ac_layer(x)
+
     x = Concatenate()([x, skip_features])
-    x = conv_block_n(x, num_filters, 2, kernel, (1, 1), padding, activation, kernel_init=kernel_init, 
-                     l_batch_normalization=l_batch_normalization)
+    x = conv_block_n(x, num_filters, 2, **kwargs_conv_block)
 
     return x
 
 
 # The particular U-net
-def sha_unet(input_shape: tuple, n_predictands_dyn: int, channels_start: int = 56, z_branch: bool = False,
-             concat_out: bool = False, tar_channels=["output_dyn", "output_z"]) -> Model:
+def sha_unet(input_shape: tuple, n_predictands_dyn: int, hparams_unet: dict, concat_out: bool = False,
+             tar_channels=["output_dyn", "output_z"]) -> Model:
     """
     Builds up U-net model architecture adapted from Sha et al., 2020 (see https://doi.org/10.1175/JAMC-D-20-0057.1).
     :param input_shape: shape of input-data
     :param channels_start: number of channels to use as start in encoder blocks
     :param n_predictands: number of target variables (dynamic output variables)
     :param z_branch: flag if z-branch is used.
-    :param tar_channels: name of output/target channels (needed for associating losses during compilation)
+    :param hparams_unet: dictionary of hyperparameters configuring the U-Net building blocks (activation, pooling type, subpixel upsampling, z-branch, etc.)
     :param concat_out: boolean if output layers will be concatenated (disables named target channels!)
+    :param tar_channels: name of output/target channels (needed for associating losses during compilation)
     :return:
     """
+    # basic configuration of U-Net 
+    channels_start = hparams_unet["ngf"]
+    z_branch = hparams_unet["z_branch"]
+    kernel_pool = hparams_unet["kernel_pool"]   
+    l_avgpool = hparams_unet["l_avgpool"]
+    l_subpixel = hparams_unet["l_subpixel"]
+
+    config_conv = {"kernel": hparams_unet["kernel"], "strides": hparams_unet["strides"], "padding": hparams_unet["padding"], 
+                   "activation": hparams_unet["activation"], "activation_args": hparams_unet["activation_args"], 
+                   "kernel_init": hparams_unet["kernel_init"], "l_batch_normalization": hparams_unet["l_batch_normalization"]}
+
+    # build U-Net
     inputs = Input(input_shape)
 
     """ encoder """
-    s1, e1 = encoder_block(inputs, channels_start, l_large=True)
-    s2, e2 = encoder_block(e1, channels_start * 2, l_large=False)
-    s3, e3 = encoder_block(e2, channels_start * 4, l_large=False)
+    s1, e1 = encoder_block(inputs, channels_start, l_large=True, kernel_pool=kernel_pool, l_avgpool=l_avgpool, **config_conv)
+    s2, e2 = encoder_block(e1, channels_start * 2, l_large=False, kernel_pool=kernel_pool, l_avgpool=l_avgpool, **config_conv)
+    s3, e3 = encoder_block(e2, channels_start * 4, l_large=False, kernel_pool=kernel_pool, l_avgpool=l_avgpool, **config_conv)
 
     """ bridge encoder <-> decoder """
-    b1 = conv_block(e3, channels_start * 8)
+    b1 = conv_block(e3, channels_start * 8, **config_conv)
 
     """ decoder """
-    d1 = decoder_block(b1, s3, channels_start * 4)
-    d2 = decoder_block(d1, s2, channels_start * 2)
-    d3 = decoder_block(d2, s1, channels_start)
+    d1 = decoder_block(b1, s3, channels_start * 4, l_subpixel=l_subpixel, **config_conv)
+    d2 = decoder_block(d1, s2, channels_start * 2, l_subpixel=l_subpixel, **config_conv)
+    d3 = decoder_block(d2, s1, channels_start, l_subpixel=l_subpixel, **config_conv)
 
-    output_dyn = Conv2D(n_predictands_dyn, (1, 1), kernel_initializer="he_normal", name=tar_channels[0])(d3)
+    output_dyn = Conv2D(n_predictands_dyn, (1, 1), kernel_initializer=config_conv["kernel_init"], name=tar_channels[0])(d3)
     if z_branch:
         print("Use z_branch...")
-        output_static = Conv2D(1, (1, 1), kernel_initializer="he_normal", name=tar_channels[1])(d3)
+        output_static = Conv2D(1, (1, 1), kernel_initializer=config_conv["kernel_init"], name=tar_channels[1])(d3)
 
         if concat_out:
             model = Model(inputs, tf.concat([output_dyn, output_static], axis=-1), name="downscaling_unet_with_z")
@@ -173,7 +212,7 @@ class UNET(keras.Model):
         self.hparams = UNET.get_hparams_dict(hparams)
         self.n_predictands = len(varnames_tar)                      # number of predictands
         self.n_predictands_dyn = self.n_predictands - 1 if self.hparams["z_branch"] else self.n_predictands
-        if self.hparams["l_embed"]:
+        if self.hparams.get("l_embed", False):
             raise ValueError("Embedding is not implemented yet.")
         self.modelname = exp_name
         if not os.path.isdir(savedir):
@@ -186,7 +225,6 @@ class UNET(keras.Model):
         See https://stackoverflow.com/questions/65318036/is-it-possible-to-use-the-tensorflow-keras-functional-api-train_unet-model-err.6387845within-a-subclassed-mo
         for a reference how a model based on Keras functional API has to be integrated into a subclass.
         """
-        print(**kwargs)
         return self.unet(inputs, **kwargs)
 
     def get_compile_opts(self):
@@ -210,13 +248,14 @@ class UNET(keras.Model):
         return opt_dict
 
     def compile(self, **kwargs):
+
         # instantiate model
         if self.hparams["z_branch"]:     # model has named branches (see also opt_dict in get_compile_opts)
             tar_channels = [f"{var}" for var in self.varnames_tar]
-            self.unet = self.unet(self.shape_in, self.n_predictands_dyn, z_branch=True,
-                                  concat_out= False, tar_channels=tar_channels)
+            self.unet = self.unet(self.shape_in, self.n_predictands_dyn, self.hparams,
+                                  concat_out=False, tar_channels=tar_channels)
         else:
-            self.unet = self.unet(self.shape_in, self.n_predictands_dyn, z_branch=False)
+            self.unet = self.unet(self.shape_in, self.n_predictands_dyn, self.hparams)
 
         return self.unet.compile(**kwargs)
        # return super(UNET, self).compile(**kwargs)
@@ -257,7 +296,9 @@ class UNET(keras.Model):
             callback_list = callback_list + [LearningRateScheduler(self.get_lr_scheduler(), verbose=1)]
 
         if self.hparams["lscheduled_train"]:
-            callback_list = callback_list + [ModelCheckpoint(self.savedir, monitor="val_loss", verbose=1,
+            savedir_best = os.path.join(self.savedir, f"{self.modelname}_best")
+            os.makedirs(savedir_best, exist_ok=True)
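+            # NOTE: the checkpoint monitors the validation loss of the named 2m-temperature
+            #       output branch, i.e. it assumes that "t_2m_tar" is among the named predictands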
+            callback_list = callback_list + [ModelCheckpoint(savedir_best, monitor="val_t_2m_tar_loss", verbose=1,
                                                              save_best_only=True, mode="min")]
             #                                + EarlyStopping(monitor="val_recon_loss", patience=8)]
 
@@ -279,6 +320,9 @@ class UNET(keras.Model):
     def save(self, **kwargs):
         self.unet.save(**kwargs)
 
+    def plot_model(self, save_dir, **kwargs):
+        # Note: keras.Model has no plot_model-method; delegate to keras.utils.plot_model instead
+        # (signature mirrors the WGAN counterpart: save_dir plus plotting kwargs)
+        keras.utils.plot_model(self.unet, os.path.join(save_dir, f"plot_{self.modelname}.png"), **kwargs)
+
     @staticmethod
     def get_hparams_dict(hparams_user: dict) -> dict:
         """
@@ -312,10 +356,13 @@ class UNET(keras.Model):
         """
         Return default hyperparameter dictionary.
         """
-        hparams_dict = {"batch_size": 32, "lr": 5.e-05, "nepochs": 70, "z_branch": True, "loss_func": "mae",
-                        "loss_weights": [1.0, 1.0], "lr_decay": False, "decay_start": 5, "decay_end": 30,
-                        "lr_end": 1.e-06, "l_embed": False, "ngf": 56, "optimizer": "adam", "lscheduled_train": True,
-                        "var_tar2in": "", "n_predictands": 1}
+
+        hparams_dict = {"kernel": (3, 3), "strides": (1, 1), "padding": "same", "activation": "swish", "activation_args": {},      # arguments for building blocks of U-Net:
+                        "kernel_init": "he_normal", "l_batch_normalization": True, "kernel_pool": (2, 2), "l_avgpool": True,     # see keyword-aguments of sha_unet, conv_block,
+                        "l_subpixel": True, "z_branch": True, "ngf": 56,                                                         # encoder_block and decoder_block
+                        "batch_size": 32, "lr": 5.e-05, "nepochs": 70, "loss_func": "mae", "loss_weights": [1.0, 1.0],            # training parameters
+                        "lr_decay": False, "decay_start": 5, "decay_end": 30, "lr_end": 1.e-06, "l_embed": False, 
+                        "optimizer": "adam", "lscheduled_train": True}
 
         return hparams_dict
 
diff --git a/downscaling_ap5/models/wgan_model.py b/downscaling_ap5/models/wgan_model.py
index 1e14df96c255b6fb96798099b3f2ba0176ad3eec..53ea658af545b465b26391fb4cf2db440c39a3b3 100644
--- a/downscaling_ap5/models/wgan_model.py
+++ b/downscaling_ap5/models/wgan_model.py
@@ -5,7 +5,7 @@
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-05-19"
-__update__ = "2022-11-25"
+__update__ = "2023-08-08"
 
 import os, sys
 from typing import List, Tuple, Union
@@ -23,14 +23,15 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.keras.layers import (Input, Dense, GlobalAveragePooling2D)
 from tensorflow.keras.models import Model
 # other modules
+from custom_losses import get_custom_loss
 from unet_model import conv_block
-from other_utils import to_list
+from other_utils import to_list, merge_dicts
 
 list_or_tuple = Union[List, Tuple]
 
 
 def critic_model(shape, num_conv: int = 4, channels_start: int = 64, kernel: tuple = (3, 3),
-                 stride: tuple = (2, 2), activation: str = "relu", lbatch_norm: bool = True):
+                 stride: tuple = (2, 2), activation: str = "swish", lbatch_norm: bool = True):
     """
     Set-up convolutional discriminator model that is followed by two dense-layers
     :param shape: input shape of data (either real or generated data)
@@ -56,7 +57,7 @@ def critic_model(shape, num_conv: int = 4, channels_start: int = 64, kernel: tup
 
     # finally perform global average pooling and finalize by fully connected layers
     x = GlobalAveragePooling2D()(x)
-    x = Dense(channels_start)(x)
+    x = Dense(channels_start, activation=activation)(x)
     # ... and end with linear output layer
     out = Dense(1, activation="linear")(x)
 
@@ -91,7 +92,7 @@ class WGAN(keras.Model):
         if self.hparams["l_embed"]:
             raise ValueError("Embedding is not implemented yet.")
         self.n_predictands = len(varnames_tar)                      # number of predictands
-        self.n_predictands_dyn = self.n_predictands - 1 if self.hparams["z_branch"] else self.n_predictands
+        self.n_predictands_dyn = self.n_predictands - 1 if self.hparams["hparams_generator"]["z_branch"] else self.n_predictands
         print(f"Dynamic predictands: {self.n_predictands_dyn}, Predictands: {self.n_predictands}")
         self.modelname = exp_name
         if not os.path.isdir(savedir):
@@ -100,10 +101,16 @@ class WGAN(keras.Model):
 
         # instantiate submodels
         # instantiate model components (generator and critci)
-        self.generator = self.generator(self.shape_in, self.n_predictands, channels_start=self.hparams["ngf"],
-                                        concat_out=True, z_branch=self.hparams["z_branch"])
+        self.generator = self.generator(self.shape_in, self.n_predictands_dyn, self.hparams["hparams_generator"], concat_out=True)
         tar_shape = (*self.shape_in[:-1], self.n_predictands_dyn)   # critic only accounts for dynamic predictands
-        self.critic = self.critic(tar_shape)
+        critic_kwargs = self.hparams["hparams_critic"].copy()
+        critic_kwargs.pop("lr", None)
+        self.critic = self.critic(tar_shape, **critic_kwargs)
+
+        # losses
+        self.critic_loss = get_custom_loss("critic")
+        self.critic_gen_loss = get_custom_loss("critic_generator")
+        self.recon_loss = self.get_recon_loss() 
 
         # Unused attribute, but introduced for joint driver script with U-Net; to be solved with customized target vars
         self.varnames_tar = None
@@ -113,6 +120,19 @@ class WGAN(keras.Model):
         self.lr_scheduler = None
         self.checkpoint, self.earlystopping = None, None
 
+    def get_recon_loss(self):
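+        """
+        Construct the reconstruction loss from the hyperparameters: vector-valued losses
+        ('*_vec') additionally receive nd_vec, while channel-wise losses ('*_channels') only
+        receive the number of predictands.
+        :return: reconstruction loss function operating on (y_true, y_pred)
+        """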
+
+        kwargs_loss = {}
+        if "vec" in self.hparams["recon_loss"]:
+            kwargs_loss = {"nd_vec": self.hparams.get("nd_vec", 2), "n_channels": self.n_predictands}
+        elif "channels" in self.hparams["recon_loss"]:
+            kwargs_loss = {"n_channels": self.n_predictands}
+        
+        loss_fn = get_custom_loss(self.hparams["recon_loss"], **kwargs_loss)
+
+        return loss_fn
+
+
     def compile(self, **kwargs):
         """
         Instantiate generator and critic, compile model and then set optimizer as well optional learning rate decay and
@@ -143,7 +163,7 @@ class WGAN(keras.Model):
         :return: learning rate scheduler
         """
         decay_st, decay_end = self.hparams["decay_start"], self.hparams["decay_end"]
-        lr_start, lr_end = self.hparams["lr_gen"], self.hparams["lr_gen_end"]
+        lr_start, lr_end = self.hparams["hparams_generator"]["lr"], self.hparams["hparams_generator"]["lr_end"]
 
         if not decay_end > decay_st:
             raise ValueError("Epoch for end of learning rate decay must be large than start epoch. " +
@@ -206,7 +226,7 @@ class WGAN(keras.Model):
                 critic_gen = self.critic(gen_data[..., 0:self.n_predictands_dyn], training=True)
                 critic_gt = self.critic(predictands_critic, training=True)
                 # calculate the loss (incl. gradient penalty)
-                c_loss = WGAN.critic_loss(critic_gt, critic_gen)
+                c_loss = self.critic_loss(critic_gt, critic_gen)
                 gp = self.gradient_penalty(predictands_critic, gen_data[..., 0:self.n_predictands_dyn])
 
                 d_loss = c_loss + self.hparams["gp_weight"] * gp
@@ -221,7 +241,7 @@ class WGAN(keras.Model):
             gen_data = self.generator(predictors[-self.hparams["batch_size"]:, :, :, :], training=True)
             # get the critic and calculate corresponding generator losses (critic and reconstruction loss)
             critic_gen = self.critic(gen_data[..., 0:self.n_predictands_dyn], training=True)
-            cg_loss = WGAN.critic_gen_loss(critic_gen)
+            cg_loss = self.critic_gen_loss(critic_gen)
             rloss = self.recon_loss(predictands[-self.hparams["batch_size"]:, :, :, :], gen_data)
 
             g_loss = cg_loss + self.hparams["recon_weight"] * rloss
@@ -275,14 +295,6 @@ class WGAN(keras.Model):
 
         return gp
 
-    def recon_loss(self, real_data, gen_data):
-        # initialize reconstruction loss
-        rloss = 0.
-        # get MAE for all output heads
-        for i in range(self.hparams["n_predictands"]):
-            rloss += tf.reduce_mean(tf.abs(tf.squeeze(gen_data[..., i]) - real_data[..., i]))
-
-        return rloss
 
     def plot_model(self, save_dir, **kwargs):
         """
@@ -340,28 +352,22 @@ class WGAN(keras.Model):
 
         # check if parsed hyperparameters are known
         unknown_keys = [key for key in hparams_user.keys() if key not in hparams_default]
-        if unknown_keys:
+        if unknown_keys:        # if unknown keys are found, remove them from user-defined hyperparameters
             print("The following parsed hyperparameters are unknown and thus are ignored: {0}".format(
                 ", ".join(unknown_keys)))
+            for unknown_key in unknown_keys:
+                hparams_user.pop(unknown_key)
+            
+        hparams_dict = merge_dicts(hparams_default, hparams_user)   # merge default and user-defined hyperparameters
 
-        # get complete hyperparameter dictionary while checking type of parsed values
-        hparams_merged = {**hparams_default, **hparams_user}
-        hparams_dict = {}
-        for key in hparams_default:
-            if isinstance(hparams_merged[key], type(hparams_default[key])):
-                hparams_dict[key] = hparams_merged[key]
-            else:
-                raise TypeError("Parsed hyperparameter '{0}' must be of type '{1}', but is '{2}'"
-                                .format(key, type(hparams_default[key]), type(hparams_merged[key])))
-
+        # check if optimizer is valid and set corresponding optimizers for generator and critic
         if hparams_dict["optimizer"].lower() == "adam":
             adam = keras.optimizers.Adam
-            hparams_dict["d_optimizer"] = adam(learning_rate=hparams_dict["lr_critic"], beta_1=0.0, beta_2=0.9)
-            hparams_dict["g_optimizer"] = adam(learning_rate=hparams_dict["lr_gen"], beta_1=0.0, beta_2=0.9)
+            hparams_dict["d_optimizer"] = adam(learning_rate=hparams_dict["hparams_critic"]["lr"], beta_1=0.0, beta_2=0.9)
+            hparams_dict["g_optimizer"] = adam(learning_rate=hparams_dict["hparams_generator"]["lr"], beta_1=0.0, beta_2=0.9)
         elif hparams_dict["optimizer"].lower() == "rmsprop":
             rmsprop = keras.optimizers.RMSprop
-            hparams_dict["d_optimizer"] = rmsprop(lr=hparams_dict["lr_critic"])  # increase beta-values ?
-            hparams_dict["g_optimizer"] = rmsprop(lr=hparams_dict["lr_gen"])
+            hparams_dict["d_optimizer"] = rmsprop(lr=hparams_dict["hparams_critic"]["lr"])  # increase beta-values ?
+            hparams_dict["g_optimizer"] = rmsprop(lr=hparams_dict["hparams_generator"]["lr"])
         else:
             raise ValueError("'{0}' is not a valid optimizer. Either choose Adam or RMSprop-optimizer")
 
@@ -372,32 +378,18 @@ class WGAN(keras.Model):
         """
         Return default hyperparameter dictionary.
         """
-        hparams_dict = {"batch_size": 32, "lr_gen": 1.e-05, "lr_critic": 1.e-06, "nepochs": 50, "z_branch": False,
-                        "lr_decay": False, "decay_start": 5, "decay_end": 10, "lr_gen_end": 1.e-06, "l_embed": False,
-                        "ngf": 56, "d_steps": 5, "recon_weight": 1000., "gp_weight": 10., "optimizer": "adam", 
-                        "lscheduled_train": True, "var_tar2in": "", "n_predictands": 2}
-
+        hparams_dict = {"batch_size": 32, "nepochs": 50, "lr_decay": False, "decay_start": 5, "decay_end": 10, 
+                        "l_embed": False, "d_steps": 5, "recon_weight": 1000., "gp_weight": 10., "optimizer": "adam", 
+                        "lscheduled_train": True, "recon_loss": "mae_channels", 
+                        "hparams_generator": {"kernel": (3, 3), "strides": (1, 1), "padding": "same", "activation": "swish", "activation_args": {},      # arguments for building blocks of U-Net:
+                                              "kernel_init": "he_normal", "l_batch_normalization": True, "kernel_pool": (2, 2), "l_avgpool": True,     # see keyword arguments of sha_unet, conv_block,
+                                              "l_subpixel": True, "z_branch": True, "ngf": 56, "lr": 1.e-05, "lr_end": 1.e-06}, 
+                        "hparams_critic": {"num_conv": 4, "channels_start": 64, "activation": "swish",
+                                           "lbatch_norm": True, "kernel": (3, 3), "stride": (2, 2), "lr": 1.e-06,}
+                       }
+                     
         return hparams_dict
 
-    @staticmethod
-    def critic_loss(critic_real, critic_gen):
-        """
-        The critic is optimized to maximize the difference between the generated and the real data max(real - gen).
-        This is equivalent to minimizing the negative of this difference, i.e. min(gen - real) = max(real - gen)
-        :param critic_real: critic on the real data
-        :param critic_gen: critic on the generated data
-        :return c_loss: loss to optize the critic
-        """
-        c_loss = tf.reduce_mean(critic_gen - critic_real)
-
-        return c_loss
-
-    @staticmethod
-    def critic_gen_loss(critic_gen):
-        cg_loss = -tf.reduce_mean(critic_gen)
-
-        return cg_loss
-
 
 class LearningRateSchedulerWGAN(LearningRateScheduler):
 
diff --git a/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_gradient_ratio.ipynb b/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_gradient_ratio.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..1ea138b3d136b972b2330dd3a1949491e3baa5c8
--- /dev/null
+++ b/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_gradient_ratio.ipynb
@@ -0,0 +1,183 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6d799556-9628-4ed8-acfa-5a8f31190e5c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "%matplotlib inline\n",
+    "\n",
+    "import os\n",
+    "import numpy as np\n",
+    "import xarray as xr\n",
+    "import pandas as pd\n",
+    "\n",
+    "import matplotlib.pyplot as plt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8a64d03c-882d-4a07-9fcc-2a15adbb10fa",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# auxiliary methods\n",
+    "\n",
+    "def create_lines_plot(data: xr.DataArray, data_std: xr.DataArray, model_names: str, metric: dict,\n",
+    "                      plt_fname: str, x_coord: str = \"hour\", **kwargs):\n",
+    "\n",
+    "    # get some plot parameters\n",
+    "    linestyle = kwargs.get(\"linestyle\", [\"k-\", \"b-\"])\n",
+    "    err_col = kwargs.get(\"error_color\", [\"grey\", \"blue\"])\n",
+    "    val_range = kwargs.get(\"value_range\", (0.7, 1.1))\n",
+    "    fs = kwargs.get(\"fs\", 16)\n",
+    "    ref_line = kwargs.get(\"ref_line\", None)\n",
+    "    ref_linestyle = kwargs.get(\"ref_linestyle\", \"k--\")\n",
+    "    \n",
+    "    fig, (ax) = plt.subplots(1, 1, figsize=(8, 6))\n",
+    "    for i, exp in enumerate(data[\"exp\"]):\n",
+    "        ax.plot(data[x_coord].values, data.sel({\"exp\": exp}).values, linestyle[i],\n",
+    "                label=model_names[i])\n",
+    "        ax.fill_between(data[x_coord].values, data.sel({\"exp\": exp}).values-data_std.sel({\"exp\": exp}).values,\n",
+    "                        data.sel({\"exp\": exp}).values+data_std.sel({\"exp\": exp}).values, facecolor=err_col[i],\n",
+    "                        alpha=0.2)\n",
+    "    if ref_line is not None:\n",
+    "        nval = np.shape(data[x_coord].values)[0]\n",
+    "        ax.plot(data[x_coord].values, np.full(nval, ref_line), ref_linestyle)\n",
+    "    ax.set_ylim(*val_range)\n",
+    "    ax.set_yticks(np.arange(*val_range, 0.05))\n",
+    "    # label axis\n",
+    "    ax.set_xlabel(\"daytime [UTC]\", fontsize=fs)\n",
+    "    metric_name, metric_unit = list(metric.keys())[0], list(metric.values())[0]\n",
+    "    ax.set_ylabel(f\"{metric_name} T2m [{metric_unit}]\", fontsize=fs)\n",
+    "    ax.tick_params(axis=\"both\", which=\"both\", direction=\"out\", labelsize=fs-2)\n",
+    "    ax.legend(fontsize=fs-2, loc=\"upper right\")\n",
+    "\n",
+    "    # save plot and close figure\n",
+    "    plt_fname = plt_fname + \".png\" if not plt_fname.endswith(\".png\") else plt_fname\n",
+    "    print(f\"Save plot in file '{plt_fname}'\")\n",
+    "    #plt.tight_layout()\n",
+    "    #fig.savefig(plt_fname)\n",
+    "    fig.savefig(plt_fname, bbox_inches=\"tight\")\n",
+    "    plt.close(fig)\n",
+    "\n",
+    "def get_id_from_fname(fname):\n",
+    "    try:\n",
+    "        start_index = fname.find(\"id\") + 2            # Adding 2 to move past \"id\"\n",
+    "        end_index = fname.find(\"_\", start_index)\n",
+    "        \n",
+    "        exp_id = fname[start_index:end_index]\n",
+    "    except:\n",
+    "        raise ValueError(f\"Failed to deduce experiment ID from '{fname}'\")\n",
+    "        \n",
+    "    return exp_id"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d8470099-80b7-46b5-bb68-933d5549ec0b",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# parameters\n",
+    "results_basedir = \"/p/home/jusers/langguth1/juwels/downscaling_maelstrom/downscaling_jsc_repo/downscaling_ap5/results\"\n",
+    "plt_dir = os.path.join(results_basedir, \"meta\")\n",
+    "\n",
+    "exp1 = \"wgan_t2m_atmorep_test\"\n",
+    "exp2 = \"atmorep_id26n32cey\"\n",
+    "\n",
+    "varname = \"T2m\"\n",
+    "year = 2018"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c05778d6-0835-40b9-8010-149c05e00519",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# main\n",
+    "os.makedirs(plt_dir, exist_ok=True)\n",
+    "\n",
+    "fexp1 = os.path.join(results_basedir, exp1, \"metric_files\", \"eval_grad_amplitude_year.csv\")\n",
+    "fexp2 = os.path.join(results_basedir, exp2, \"metric_files\", \"eval_grad_amplitude__small_dom_year.csv\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3634bcd3-40f7-4497-9561-3f41c00fc039",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "dims = [\"hour\", \"type\"]\n",
+    "coord_dict = {\"hour\": np.arange(24), \"type\": [\"mean\", \"std\"]}\n",
+    "\n",
+    "da_gr_exp1 = xr.DataArray(pd.read_csv(fexp1, header=0, index_col=0), dims=dims, coords=coord_dict)\n",
+    "da_gr_exp2 = xr.DataArray(pd.read_csv(fexp2, header=0, index_col=0), dims=dims, coords=coord_dict)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "492e8f95-e00f-43f6-951e-72e9a2a60d35",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "da_gr_all = xr.concat([da_gr_exp1, da_gr_exp2], dim= \"exp\")\n",
+    "da_gr_all = da_gr_all.assign_coords({\"exp\": [exp1, exp2]})\n",
+    "\n",
+    "# create plot\n",
+    "plt_fname = os.path.join(plt_dir, f\"eval_grad_amplitude_{exp1}_{exp2}.png\")\n",
+    "create_lines_plot(da_gr_all.sel({\"type\": \"mean\"}), da_gr_all.sel({\"type\": \"std\"}),\n",
+    "                  [\"WGAN\", \"AtmoRep\"], {\"GRAD_AMPLITUDE\": \"1\"}, plt_fname, re)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f0bb44d2-85eb-4bd8-9922-ba3ce5ffd38d",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "langguth1_downscaling_kernel_juwels",
+   "language": "python",
+   "name": "langguth1_downscaling_kernel_juwels"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_rmse.ipynb b/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_rmse.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8b68161292e052e37b22bf578cd7d48a6e593676
--- /dev/null
+++ b/downscaling_ap5/postprocess/jupyter_notebooks/meta_postprocessing_rmse.ipynb
@@ -0,0 +1,190 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b36a047f-7db7-4997-9959-3b4767f43345",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "%matplotlib inline\n",
+    "\n",
+    "import os\n",
+    "import numpy as np\n",
+    "import xarray as xr\n",
+    "import pandas as pd\n",
+    "\n",
+    "import matplotlib.pyplot as plt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ae48851c-7d59-4128-af9d-97795e610aea",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# auxiliary methods\n",
+    "\n",
+    "def create_lines_plot(data: xr.DataArray, data_std: xr.DataArray, model_names: str, metric: dict,\n",
+    "                      plt_fname: str, x_coord: str = \"hour\", **kwargs):\n",
+    "\n",
+    "    # get some plot parameters\n",
+    "    linestyle = kwargs.get(\"linestyle\", [\"k-\", \"b-\"])\n",
+    "    err_col = kwargs.get(\"error_color\", [\"grey\", \"blue\"])\n",
+    "    val_range = kwargs.get(\"value_range\", (0., 3.))\n",
+    "    fs = kwargs.get(\"fs\", 16)\n",
+    "    ref_line = kwargs.get(\"ref_line\", None)\n",
+    "    ref_linestyle = kwargs.get(\"ref_linestyle\", \"k--\")\n",
+    "    \n",
+    "    fig, (ax) = plt.subplots(1, 1)\n",
+    "    for i, exp in enumerate(data[\"exp\"]):\n",
+    "        ax.plot(data[x_coord].values, data.sel({\"exp\": exp}).values, linestyle[i],\n",
+    "                label=model_names[i])\n",
+    "        ax.fill_between(data[x_coord].values, data.sel({\"exp\": exp}).values-data_std.sel({\"exp\": exp}).values,\n",
+    "                        data.sel({\"exp\": exp}).values+data_std.sel({\"exp\": exp}).values, facecolor=err_col[i],\n",
+    "                        alpha=0.2)\n",
+    "    if ref_line is not None:\n",
+    "        nval = np.shape(data[x_coord].values)[0]\n",
+    "        ax.plot(data[x_coord].values, np.full(nval, ref_line), ref_linestyle)\n",
+    "    ax.set_ylim(*val_range)\n",
+    "    # label axis\n",
+    "    ax.set_xlabel(\"daytime [UTC]\", fontsize=fs)\n",
+    "    metric_name, metric_unit = list(metric.keys())[0], list(metric.values())[0]\n",
+    "    ax.set_ylabel(f\"{metric_name} T2m [{metric_unit}]\", fontsize=fs)\n",
+    "    ax.tick_params(axis=\"both\", which=\"both\", direction=\"out\", labelsize=fs-2)\n",
+    "    ax.legend(fontsize=fs-2, loc=\"upper right\")\n",
+    "\n",
+    "    # save plot and close figure\n",
+    "    plt_fname = plt_fname + \".png\" if not plt_fname.endswith(\".png\") else plt_fname\n",
+    "    print(f\"Save plot in file '{plt_fname}'\")\n",
+    "    plt.tight_layout()\n",
+    "    fig.savefig(plt_fname)\n",
+    "    plt.close(fig)\n",
+    "\n",
+    "def get_id_from_fname(fname):\n",
+    "    try:\n",
+    "        start_index = fname.find(\"id\") + 2            # Adding 2 to move past \"id\"\n",
+    "        end_index = fname.find(\"_\", start_index)\n",
+    "        \n",
+    "        exp_id = fname[start_index:end_index]\n",
+    "    except:\n",
+    "        raise ValueError(f\"Failed to deduce experiment ID from '{fname}'\")\n",
+    "        \n",
+    "    return exp_id"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6d57de46-706f-4386-8e14-1263a8c96666",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# parameters\n",
+    "results_basedir = \"/p/home/jusers/langguth1/juwels/downscaling_maelstrom/downscaling_jsc_repo/downscaling_ap5/results\"\n",
+    "plt_dir = os.path.join(results_basedir, \"meta\")\n",
+    "\n",
+    "exp1 = \"wgan_t2m_atmorep_test\"\n",
+    "exp2 = \"atmorep_id26n32cey\"\n",
+    "\n",
+    "varname = \"T2m\"\n",
+    "year = 2018"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9c2327f6-85e9-4bc4-acb6-21b110abcdc9",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# main\n",
+    "os.makedirs(plt_dir, exist_ok=True)\n",
+    "\n",
+    "fexp1 = os.path.join(results_basedir, exp1, \"metric_files\", \"eval_rmse_year.csv\")\n",
+    "fexp2 = os.path.join(results_basedir, exp2, \"metric_files\", \"eval_rmse__small_dom_year.csv\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0d5c40cd-586d-46b8-8069-2a14a6ab7eee",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "dims = [\"hour\", \"type\"]\n",
+    "coord_dict = {\"hour\": np.arange(24), \"type\": [\"mean\", \"std\"]}\n",
+    "\n",
+    "da_rmse_exp1 = xr.DataArray(pd.read_csv(fexp1, header=0, index_col=0), dims=dims, coords=coord_dict)\n",
+    "da_rmse_exp2 = xr.DataArray(pd.read_csv(fexp2, header=0, index_col=0), dims=dims, coords=coord_dict)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ca2993fb-f008-4ef4-bcff-b3a1745116ff",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "da_rmse_all = xr.concat([da_rmse_exp1, da_rmse_exp2], dim= \"exp\")\n",
+    "da_rmse_all = da_rmse_all.assign_coords({\"exp\": [exp1, exp2]})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6fd5217b-c59d-4979-81e7-790ff810af04",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "plt_fname = os.path.join(plt_dir, f\"eval_rmse_{exp1}_{exp2}.png\")\n",
+    "create_lines_plot(da_rmse_all.sel({\"type\": \"mean\"}), da_rmse_all.sel({\"type\": \"std\"}),\n",
+    "                  [\"WGAN\", \"AtmoRep\"], {\"RMSE\": \"K\"}, plt_fname)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "70d7870d-25f1-49de-a646-b1d062edc474",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "PyEarthSystem-2023.5",
+   "language": "python",
+   "name": "pyearthsystem"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/downscaling_ap5/postprocess/postprocess_wgan_era5_ifs.ipynb b/downscaling_ap5/postprocess/jupyter_notebooks/postprocess_wgan_era5_ifs.ipynb
similarity index 99%
rename from downscaling_ap5/postprocess/postprocess_wgan_era5_ifs.ipynb
rename to downscaling_ap5/postprocess/jupyter_notebooks/postprocess_wgan_era5_ifs.ipynb
index 6d42a70f7986cddb491eb2b43683e56f27c9e048..3b8a6ff2ce63be53959752f8f8b083689c7bce74 100644
--- a/downscaling_ap5/postprocess/postprocess_wgan_era5_ifs.ipynb
+++ b/downscaling_ap5/postprocess/jupyter_notebooks/postprocess_wgan_era5_ifs.ipynb
@@ -415,9 +415,9 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "langguth1_downscaling_kernel_juwels",
+   "display_name": "langguth1_downscaling_kernel",
    "language": "python",
-   "name": "langguth1_downscaling_kernel_juwels"
+   "name": "langguth1_downscaling_kernel"
   },
   "language_info": {
    "codemirror_mode": {
diff --git a/downscaling_ap5/postprocess/plotting.py b/downscaling_ap5/postprocess/plotting.py
index cb706c2dbdaf33bcf18a89186403eac3cef0a80f..ce09863fa5da0ca0b350d6066e6e4430b22daa4e 100644
--- a/downscaling_ap5/postprocess/plotting.py
+++ b/downscaling_ap5/postprocess/plotting.py
@@ -9,7 +9,7 @@ Methods for creating plots.
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-01-20"
-__update__ = "2022-12-08"
+__update__ = "2023-12-10"
 
 # for processing data
 import os
@@ -26,6 +26,9 @@ import cartopy.crs as ccrs
 from other_utils import provide_default
 
 # auxiliary variable for logger
+logger_module_name = f"main_postprocess.{__name__}"
+module_logger = logging.getLogger(logger_module_name)
 module_name = os.path.basename(__file__).rstrip(".py")
 
 
@@ -50,7 +53,7 @@ def get_colormap_temp(levels=None):
 
 
 # for making plot nice
-def decorate_plot(ax_plot, plot_xlabel=True, plot_ylabel=True):
+def decorate_plot(ax_plot, plot_xlabel=True, plot_ylabel=True, extent=[3.5, 16.5, 44.5, 54.]):
     fs = 16
     # if "login" in host:
     # add nice coast- and borderlines
@@ -62,7 +65,7 @@ def decorate_plot(ax_plot, plot_xlabel=True, plot_ylabel=True):
     ax_plot.set_xticks(np.arange(0., 360. + 0.1, 5.))  # ,crs=projection_crs)
     ax_plot.set_yticks(np.arange(-90., 90. + 0.1, 5.))  # ,crs=projection_crs)
 
-    ax_plot.set_extent([3.5, 16.5, 44.5, 54.])    # , crs=prj_crs)
+    ax_plot.set_extent(extent)    # , crs=prj_crs)
     ax_plot.minorticks_on()
     ax_plot.tick_params(axis="both", which="both", direction="out", labelsize=fs)
 
@@ -127,7 +130,7 @@ def create_map_score(score, plt_fname, **kwargs):
     func_logger = logging.getLogger(f"postprocess.{module_name}.{create_map_score.__name__}")
 
     # get keyword arguments
-    score_dims = kwargs.get("score_dims", ["lat", "lon"])
+    dims = kwargs.get("dims", ["lat", "lon"])
     title = kwargs.get("title", "Score")
     levels = kwargs.get("levels", np.arange(-5., 5., 0.5))
     # auxiliary variables
@@ -135,17 +138,22 @@ def create_map_score(score, plt_fname, **kwargs):
     nbounds = len(lvl)
     cmap = kwargs.get("cmap", mpl.cm.PuOr_r(np.linspace(0., 1., nbounds)))
     fs = kwargs.get("fs", 16)
-    projection = kwargs.get("projection", ccrs.PlateCarree())
+    projection = kwargs.get("projection", ccrs.RotatedPole(pole_longitude=-162.0, pole_latitude=39.25))
+    extent = kwargs.get("extent", None)
+    
+    decorate_dict = {}
+    if extent:
+        decorate_dict["extent"] = extent
     
     # get coordinate data
     try:
-        lat, lon = score[score_dims[0]].values, score[score_dims[1]].values
+        lat, lon = score[dims[0]].values, score[dims[1]].values
     except Exception as err:
         print("Failed to retrieve coordinates from score-data")
         raise err
     # construct array for edges of grid points
-    dy, dx = np.round((lat[1] - lat[0]), 3), np.round((lon[1] - lon[0]), 3)
-    lat_e, lon_e = np.arange(lat[0]-dy/2, lat[-1]+dy, dy), np.arange(lon[0]-dx/2, lon[-1]+dx, dx)
+    dy, dx = np.round((lat[1] - lat[0]), 4), np.round((lon[1] - lon[0]), 4)
+    lat_e, lon_e = np.arange(lat[0]-dy/2, lat[-1]+dy, dy), np.arange(lon[0]-dx/2, lon[-1]+dx, dx)  
 
     # get colormap
     # create colormap and corresponding norm
@@ -158,7 +166,7 @@ def create_map_score(score, plt_fname, **kwargs):
     plt1 = ax.pcolormesh(lon_e, lat_e, np.squeeze(score.values), cmap=cmap_obj, norm=norm, 
                          transform=projection)
 
-    ax = decorate_plot(ax)
+    ax = decorate_plot(ax, **decorate_dict)
 
     ax.set_title(title, size=fs)
 
@@ -208,3 +216,66 @@ def create_line_plot(data: xr.DataArray, data_std: xr.DataArray, model_name: str
     plt.tight_layout()
     fig.savefig(plt_fname)
     plt.close(fig)
+
+
+def create_box_plot(data, plt_fname: str, **plt_kwargs):
+    """
+    Create a box plot, e.g. of feature importance scores.
+    :param data: Data to be plotted, e.g. feature importance scores with predictors as first dimension and time as second dimension
+    :param plt_fname: File name of plot
+    """
+    func_logger = logging.getLogger(f"postprocess.{module_name}.{create_box_plot.__name__}")
+
+    # get some plot parameters
+    val_range = plt_kwargs.get("value_range", [None])
+    widths = plt_kwargs.get("widths", None)
+    colors = plt_kwargs.get("colors", None)
+    fs = plt_kwargs.get("fs", 16)
+    ref_line = plt_kwargs.get("ref_line", 1.)
+    ref_linestyle = plt_kwargs.get("ref_linestyle", "k-")
+    title = plt_kwargs.get("title", "")
+    ylabel = plt_kwargs.get("ylabel", "")
+    xlabel = plt_kwargs.get("xlabel", "")
+    yticks = plt_kwargs.get("yticks", None)
+    labels = plt_kwargs.get("labels", None)
+    
+    # create box whiskers plot with matplotlib
+    fig, ax = plt.subplots(figsize=(12, 8))
+
+    bp = plt.boxplot(data, widths=widths, labels=labels, patch_artist=True)
+    
+    # modify fliers
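+    # each box's fliers are reduced to a single marker placed at the maximum outlier value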
+    fliers = bp['fliers'] 
+    for i in range(len(fliers)): # iterate through the Line2D objects for the fliers for each boxplot
+        box = fliers[i] # this accesses the x and y vectors for the fliers for each box 
+        box.set_data([[box.get_xdata()[0]],[np.max(box.get_ydata())]])
+        
+    if ref_line is not None:
+        nval = len(fliers)
+        ax.plot(np.array(range(0, nval+1)) + 0.5, np.full(nval+1, ref_line), ref_linestyle)
+        
+    if colors is None:
+        pass
+    else:
+        if isinstance(colors, str): colors = len(bp["boxes"])*[colors]
+        for patch, color in zip(bp['boxes'], colors):
+            patch.set_facecolor(color)
+    
+    ax.set_ylim(*val_range)
+    if yticks is not None:
+        ax.set_yticks(yticks)
+    
+    ax.set_title(title, fontsize=fs + 2)
+    ax.set_ylabel(ylabel, fontsize=fs, labelpad=8)
+    ax.set_xlabel(xlabel, fontsize=fs, labelpad=8)
+    ax.tick_params(axis="both", which="both", direction="out", labelsize=fs-2)
+    ax.yaxis.grid(True)
+
+    # save plot
+    plt.tight_layout()
+    plt.savefig(plt_fname + ".png" if not plt_fname.endswith(".png") else plt_fname)
+    plt.close(fig)
+
+    func_logger.info(f"Feature importance scores saved to {plt_fname}.")
+    
+    return True
diff --git a/downscaling_ap5/postprocess/postprocess.py b/downscaling_ap5/postprocess/postprocess.py
index bac9c6ab6ce2ca33c88344537ed499a5d8fb883c..3b6b74d751404c5e0753b055f7bff845845206d8 100644
--- a/downscaling_ap5/postprocess/postprocess.py
+++ b/downscaling_ap5/postprocess/postprocess.py
@@ -9,13 +9,20 @@ Auxiliary methods for postprocessing.
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-12-08"
-__update__ = "2022-12-08"
+__update__ = "2023-10-12"
 
 import os
+from typing import Union, List
 import logging
+import numpy as np
 import xarray as xr
 from cartopy import crs
-from plotting import create_line_plot, create_map_score
+from statistical_evaluation import feature_importance
+from plotting import create_line_plot, create_map_score, create_box_plot
+
+# basic data types
+da_or_ds = Union[xr.DataArray, xr.Dataset]
+list_or_str = Union[List[str], str]
 
 # auxiliary variable for logger
 logger_module_name = f"main_postprocess.{__name__}"
@@ -28,18 +35,19 @@ def get_model_info(model_base, output_base: str, exp_name: str, bool_last: bool
     func_logger = logging.getLogger(f"{logger_module_name}.{get_model_info.__name__}")
 
     model_name = os.path.basename(model_base)
+    norm_dir = model_base
+
+    add_str = "_last" if bool_last else "_best"
 
     if "wgan" in exp_name:
         func_logger.debug(f"WGAN-modeltype detected.")
-        add_str = "_last" if bool_last else ""
         model_dir, plt_dir = os.path.join(model_base, f"{exp_name}_generator{add_str}"), \
                              os.path.join(output_base, model_name)
-        norm_dir = model_base
         model_type = "wgan"
     elif "unet" in exp_name or "deepru" in exp_name:
         func_logger.debug(f"U-Net-modeltype detected.")
-        model_dir, plt_dir = model_base, os.path.join(output_base, model_name)
-        norm_dir = model_dir
+        model_dir, plt_dir = os.path.join(model_base, f"{exp_name}{add_str}"), os.path.join(output_base, model_name)
         model_type = "unet" if "unet" in exp_name else "deepru"
     else:
         func_logger.debug(f"Model type could not be inferred from experiment name. Try my best by defaulting...")
@@ -53,6 +61,29 @@ def get_model_info(model_base, output_base: str, exp_name: str, bool_last: bool
     return model_dir, plt_dir, norm_dir, model_type
 
 
+def run_feature_importance(da: xr.DataArray, predictors: list_or_str, varname_tar: str, model, norm, score_name: str,
+                           ref_score: float, data_loader_opt: dict, plt_dir: str, patch_size = (6, 6), variable_dim = "variable"):
+
+    # get local logger
+    func_logger = logging.getLogger(f"{logger_module_name}.{run_feature_importance.__name__}")
+    
+    # get feature importance scores
+    feature_scores = feature_importance(da, predictors, varname_tar, model, norm, score_name, data_loader_opt, 
+                                        patch_size=patch_size, variable_dim=variable_dim)
+    
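+    # relative change w.r.t. the reference score (presumably obtained on unpermuted data); for error
+    # metrics such as RMSE, values > 1 then indicate a loss of skill due to the permutation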
+    rel_changes = feature_scores / ref_score
+    max_rel_change = int(np.ceil(np.amax(rel_changes) + 1.))
+
+    # plot feature importance scores in a box-plot with whiskers where each variable is a box
+    plt_fname = os.path.join(plt_dir, f"feature_importance_{score_name}.png")
+
+    create_box_plot(rel_changes.T, plt_fname, **{"title": f"Feature Importance ({score_name.upper()})", "ref_line": 1., "widths": .3, 
+                                                 "xlabel": "Predictors", "ylabel": f"Rel. change {score_name.upper()}", "labels": predictors, 
+                                                 "yticks": range(1, max_rel_change), "colors": "b"})
+
+    return feature_scores
+
+
 def run_evaluation_time(score_engine, score_name: str, score_unit: str, plot_dir: str, **plt_kwargs):
     """
     Create line plots of desired evaluation metric. Evaluation metric must have a time-dimension
@@ -64,79 +95,102 @@ def run_evaluation_time(score_engine, score_name: str, score_unit: str, plot_dir
     # get local logger
     func_logger = logging.getLogger(f"{logger_module_name}.{run_evaluation_time.__name__}")
 
+    # create output-directories if necessary 
+    metric_dir = os.path.join(plot_dir, "metric_files")
     os.makedirs(plot_dir, exist_ok=True)
+    os.makedirs(metric_dir, exist_ok=True)
+    
     model_type = plt_kwargs.get("model_type", "wgan")
 
     func_logger.info(f"Start evaluation in terms of {score_name}")
     score_all = score_engine(score_name)
+    score_all = score_all.drop_vars("variables")
 
     func_logger.info(f"Globally averaged {score_name}: {score_all.mean().values:.4f} {score_unit}, " +
-                     f"standard deviation: {score_all.std().values:.4f}")
-
+                     f"standard deviation: {score_all.std().values:.4f}")  
+    
     score_hourly_all = score_all.groupby("time.hour")
     score_hourly_mean, score_hourly_std = score_hourly_all.mean(), score_hourly_all.std()
-    for hh in range(24):
-        func_logger.debug(f"Evaluation for {hh:02d} UTC")
-        if hh == 0:
-            tmp = score_all.isel({"time": score_all.time.dt.hour == hh}).groupby("time.season")
-            score_hourly_mean_sea, score_hourly_std_sea = tmp.mean().copy(), tmp.std().copy()
-        else:
-            tmp = score_all.isel({"time": score_all.time.dt.hour == hh}).groupby("time.season")
-            score_hourly_mean_sea, score_hourly_std_sea = xr.concat([score_hourly_mean_sea, tmp.mean()], dim="hour"), \
-                                                          xr.concat([score_hourly_std_sea, tmp.std()], dim="hour")
 
     # create plots
     create_line_plot(score_hourly_mean, score_hourly_std, model_type.upper(),
                      {score_name.upper(): score_unit},
                      os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}.png"), **plt_kwargs)
 
-    for sea in score_hourly_mean_sea["season"]:
+    scores_to_csv(score_hourly_mean, score_hourly_std, score_name, fname=os.path.join(metric_dir, f"eval_{score_name}_year.csv"))
+
+    score_seas = score_all.groupby("time.season")
+    for sea, score_sea in score_seas:
+        score_sea_hh = score_sea.groupby("time.hour")
+        score_sea_hh_mean, score_sea_hh_std = score_sea_hh.mean(), score_sea_hh.std()
         func_logger.debug(f"Evaluation for season '{sea}'...")
-        create_line_plot(score_hourly_mean_sea.sel({"season": sea}),
-                         score_hourly_std_sea.sel({"season": sea}),
+        create_line_plot(score_sea_hh_mean,
+                         score_sea_hh_std,
                          model_type.upper(), {score_name.upper(): score_unit},
-                         os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{sea.values}.png"),
+                         os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{sea}.png"),
                          **plt_kwargs)
-    return True
+        
+        scores_to_csv(score_sea_hh_mean, score_sea_hh_std, score_name, 
+                      fname=os.path.join(metric_dir, f"eval_{score_name}_{sea}.csv"))
+    return score_all
 
 
-def run_evaluation_spatial(score_engine, score_name: str, plot_dir: str, **plt_kwargs):
+def run_evaluation_spatial(score_engine, score_name: str, plot_dir: str, 
+                           dims = ["rlat", "rlon"], **plt_kwargs):
     """
     Create map plots of desired evaluation metric. Evaluation metric must be given in rotated coordinates.
-    To-Do: Add flexibility regarding underlying coordinate data (i.e. projection).
     :param score_engine: Score engine object to compute evaluation metric
-    :param score_name: Name of evaluation metric (must be implemented into score_engine)
     :param plot_dir: Directory to save plot files
+    :param dims: Spatial dimension names
     """
     # get local logger
-    func_logger = logging.getLogger(f"{logger_module_name}.{run_evaluation_spatial.__name__}")
-
+    func_logger = logging.getLogger(f"{logger_module_name}.{run_evaluation_spatial.__name__}")
+    
     os.makedirs(plot_dir, exist_ok=True)
 
     model_type = plt_kwargs.get("model_type", "wgan")
     score_all = score_engine(score_name)
-    cosmo_prj = crs.RotatedPole(pole_longitude=-162.0, pole_latitude=39.25)
+    score_all = score_all.drop_vars("variables")
 
     score_mean = score_all.mean(dim="time")
     fname = os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_avg_map.png")
-    create_map_score(score_mean, fname, score_dims=["rlat", "rlon"],
-                     title=f"{score_name.upper()} (avg.)", projection=cosmo_prj, **plt_kwargs)
+    create_map_score(score_mean, fname, dims=dims,
+                     title=f"{score_name.upper()} (avg.)", **plt_kwargs)
 
     score_hourly_mean = score_all.groupby("time.hour").mean(dim=["time"])
     for hh in range(24):
         func_logger.debug(f"Evaluation for {hh:02d} UTC")
         fname = os.path.join(plot_dir, f"downscaling_{model_type}_{score_name.lower()}_{hh:02d}_map.png")
         create_map_score(score_hourly_mean.sel({"hour": hh}), fname,
-                         score_dims=["rlat", "rlon"], title=f"{score_name.upper()} {hh:02d} UTC",
-                         projection=cosmo_prj, **plt_kwargs)
+                         dims=dims, title=f"{score_name.upper()} {hh:02d} UTC",
+                         **plt_kwargs)
 
     for hh in range(24):
         score_now = score_all.isel({"time": score_all.time.dt.hour == hh}).groupby("time.season").mean(dim="time")
         for sea in score_now["season"]:
-            func_logger.debug(f"Evaluation for season '{str(sea)}' at {hh:02d} UTC")
+            func_logger.debug(f"Evaluation for season '{str(sea.values)}' at {hh:02d} UTC")
             fname = os.path.join(plot_dir,
                                  f"downscaling_{model_type}_{score_name.lower()}_{sea.values}_{hh:02d}_map.png")
-            create_map_score(score_now.sel({"season": sea}), fname, score_dims=["rlat", "rlon"],
-                             title=f"{score_name} {sea.values} {hh:02d} UTC", projection=cosmo_prj, **plt_kwargs)
+            create_map_score(score_now.sel({"season": sea}), fname, dims=dims,
+                             title=f"{score_name} {sea.values} {hh:02d} UTC", **plt_kwargs)
 
     return True
+
+
+def scores_to_csv(score_mean, score_std, score_name, fname="scores.csv"):
+    """
+    Save scores to csv file
+    :param score_mean: Hourly mean of score
+    :param score_std: Hourly standard deviation of score
+    :param score_name: Name of score
+    :param fname: Filename of csv file
+    """
+    # get local logger
+    func_logger = logging.getLogger(f"{logger_module_name}.{scores_to_csv.__name__}")
+    
+    df_mean = score_mean.to_dataframe(name=f"{score_name}_mean")
+    df_std = score_std.to_dataframe(name=f"{score_name}_std")
+    df = df_mean.join(df_std)
+
+    func_logger.info(f"Save values of {score_name} to {fname}...")
+    df.to_csv(fname)
diff --git a/downscaling_ap5/postprocess/statistical_evaluation.py b/downscaling_ap5/postprocess/statistical_evaluation.py
index 1c767e2d312221ff7b74f207a1a54a5c18465f00..a555fc5e8ce0bfc88659b16da4413c9d75205bbe 100644
--- a/downscaling_ap5/postprocess/statistical_evaluation.py
+++ b/downscaling_ap5/postprocess/statistical_evaluation.py
@@ -8,22 +8,31 @@ Collection of auxiliary functions for statistical evaluation and class for Score
 
 __email__ = "m.langguth@fz-juelich.de"
 __author__ = "Michael Langguth"
-__date__ = "2022-09-11"
+__date__ = "2023-10-10"
 
-import numpy as np
-import xarray as xr
 from typing import Union, List
-import datetime
-import pandas as pd
 try:
     from tqdm import tqdm
     l_tqdm = True
 except:
     l_tqdm = False
+import logging
+import numpy as np
+import pandas as pd
+import xarray as xr
+from skimage.util.shape import view_as_blocks
+from handle_data_class import HandleDataClass
+from model_utils import convert_to_xarray
 from other_utils import provide_default, check_str_in_list
 
+
 # basic data types
 da_or_ds = Union[xr.DataArray, xr.Dataset]
+list_or_str = Union[List[str], str]
+
+# auxiliary variable for logger
+logger_module_name = f"main_postprocess.{__name__}"
+module_logger = logging.getLogger(logger_module_name)
 
 
 def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, factorization="calibration_refinement",
@@ -36,21 +45,30 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa
     :param quantiles: conditional quantiles
     :return quantile_panel: conditional quantiles of p(m|o) or p(o|m)
     """
-    method = calculate_cond_quantiles.__name__
+    # get local logger
+    func_logger = logging.getLogger(f"{logger_module_name}.{calculate_cond_quantiles.__name__}")
 
     # sanity checks
     if not isinstance(data_fcst, xr.DataArray):
-        raise ValueError("%{0}: data_fcst must be a DataArray.".format(method))
+        err_mess = f"data_fcst must be a DataArray, but is of type '{type(data_fcst)}'."
+        func_logger.error(err_mess, stack_info=True, exc_info=True)
+        raise ValueError(err_mess)
 
     if not isinstance(data_ref, xr.DataArray):
-        raise ValueError("%{0}: data_ref must be a DataArray.".format(method))
+        err_mess = f"data_ref must be a DataArray, but is of type '{type(data_ref)}'."
+        func_logger.error(err_mess, stack_info=True, exc_info=True)
+        raise ValueError(err_mess)
 
     if not (list(data_fcst.coords) == list(data_ref.coords) and list(data_fcst.dims) == list(data_ref.dims)):
-        raise ValueError("%{0}: Coordinates and dimensions of data_fcst and data_ref must be the same".format(method))
+        err_mess = f"Coordinates and dimensions of data_fcst and data_ref must be the same."
+        func_logger.error(err_mess, stack_info=True, exc_info=True)
+        raise ValueError(err_mess)
 
     nquantiles = len(quantiles)
     if not nquantiles >= 3:
-        raise ValueError("%{0}: quantiles must be a list/tuple of at least three float values ([0..1])".format(method))
+        err_mess = f"Quantiles must be a list/tuple of at least three float values ([0..1])."
+        func_logger.error(err_mess, stack_info=True, exc_info=True)
+        raise ValueError(err_mess)
 
     if factorization == "calibration_refinement":
         data_cond = data_fcst
@@ -59,8 +77,9 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa
         data_cond = data_ref
         data_tar = data_fcst
     else:
-        raise ValueError("%{0}: Choose either 'calibration_refinement' or 'likelihood-base_rate' for factorization"
-                         .format(method))
+        err_mess = f"Choose either 'calibration_refinement' or 'likelihood-base_rate' for factorization"
+        func_logger.error(err_mess, stack_info=True, exc_info=True)
+        raise ValueError(err_mess)
 
     # get and set some basic attributes
     data_cond_longname = provide_default(data_cond.attrs, "longname", "conditioning_variable")
@@ -87,7 +106,7 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa
                                   attrs={"cond_var_name": data_cond_longname, "cond_var_unit": data_cond_unit,
                                          "tar_var_name": data_tar_longname, "tar_var_unit": data_tar_unit})
     
-    print("%{0}: Start caclulating conditional quantiles for all {1:d} bins.".format(method, nbins))
+    func_logger.info(f"Start caclulating conditional quantiles for all {nbins:d} bins.")
     # fill the quantile data array
     for i in np.arange(nbins):
         # conditioning of ground truth based on forecast
@@ -97,6 +116,37 @@ def calculate_cond_quantiles(data_fcst: xr.DataArray, data_ref: xr.DataArray, fa
 
     return quantile_panel, data_cond
 
+def get_cdf_of_x(sample_in, prob_in):
+    """
+    Wrapper for interpolating CDF-values for given data
+    :param sample_in : input values to derive discrete CDF
+    :param prob_in   : corresponding CDF
+    :return: lambda function converting arbitrary input values to corresponding CDF value
+    """
+    return lambda xin: np.interp(xin, sample_in, prob_in)
+
+def get_seeps_matrix(seeps_param):
+    """
+    Converts SEEPS parameter array to SEEPS matrix.
+    :param seeps_param: Array providing p1 and p3 parameters of SEEPS weighting matrix.
+    :return seeps_matrix: 3x3 weighting matrix for the SEEPS-score
+    """
+    # initialize matrix
+    seeps_weights = xr.full_like(seeps_param["p1"], np.nan)
+    seeps_weights = seeps_weights.expand_dims(dim={"weights":np.arange(9)}, axis=0).copy()
+    seeps_weights.name = "SEEPS weighting matrix"
+    
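+    # the 'weights'-dimension holds the row-wise flattened 3x3 SEEPS scoring matrix for the three
+    # precipitation categories (dry, light, heavy); indices 0, 4 and 8 are its diagonal entries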
+    # off-diagonal elements
+    seeps_weights[{"weights": 1}] = 1./(1. - seeps_param["p1"])
+    seeps_weights[{"weights": 2}] = 1./seeps_param["p3"] + 1./(1. - seeps_param["p1"])
+    seeps_weights[{"weights": 3}] = 1./seeps_param["p1"]
+    seeps_weights[{"weights": 5}] = 1./seeps_param["p3"]
+    seeps_weights[{"weights": 6}] = 1./seeps_param["p1"] + 1./(1. - seeps_param["p3"])
+    seeps_weights[{"weights": 7}] = 1./(1. - seeps_param["p3"])
+    # diagonal elements
+    seeps_weights[{"weights": [0, 4, 8]}] = xr.where(np.isnan(seeps_weights[{"weights": 7}]), np.nan, 0.)
+    
+    return seeps_weights
 
 def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length: int, nboots_block: int = 1000,
                                    seed: int = 42):
@@ -109,14 +159,18 @@ def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length
     :param seed: seed for random block sampling (to be held constant for reproducibility)
     :return: bootstrapped version of metric(-s)
     """
-
-    method = perform_block_bootstrap_metric.__name__
+    # get local logger
+    func_logger = logging.getLogger(f"{logger_module_name}.{perform_block_bootstrap_metric.__name__}")
 
     if not isinstance(metric, da_or_ds.__args__):
-        raise ValueError("%{0}: Input metric must be a xarray DataArray or Dataset and not {1}".format(method,
-                                                                                                       type(metric)))
+        err_mess = f"Input metric must be a xarray DataArray or Dataset and not {type(metric)}."
+        func_logger.error(err_mess, stack_info=True, exc_info=True)
+        raise ValueError(err_mess)
+    
     if dim_name not in metric.dims:
-        raise ValueError("%{0}: Passed dimension cannot be found in passed metric.".format(method))
+        err_mess = f"Passed dimension {dim_name} cannot be found in passed metric."
+        func_logger.error(err_mess, stack_info=True, exc_info=True)
+        raise ValueError(err_mess)
 
     metric = metric.sortby(dim_name)
 
@@ -124,8 +178,9 @@ def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length
     nblocks = int(np.floor(dim_length/block_length))
 
     if nblocks < 10:
-        raise ValueError("%{0}: Less than 10 blocks are present with given block length {1:d}."
-                         .format(method, block_length) + " Too less for bootstrapping.")
+        err_mess = f"Less than 10 blocks are present with given block length {block_length:d}. Too less for bootstrapping."
+        func_logger.error(err_mess, stack_info=True, exc_info=True)
+        raise ValueError(err_mess)
 
     # precompute metrics of block
     for iblock in np.arange(nblocks):
@@ -143,7 +198,7 @@ def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length
     np.random.seed(seed)
     iblocks_boot = np.sort(np.random.randint(nblocks, size=(nboots_block, nblocks)))
 
-    print("%{0}: Start block bootstrapping...".format(method))
+    func_logger.info("Start block bootstrapping...")
     iterator_b = np.arange(nboots_block)
     if l_tqdm:
         iterator_b = tqdm(iterator_b)
@@ -163,6 +218,281 @@ def perform_block_bootstrap_metric(metric: da_or_ds, dim_name: str, block_length
     return metric_boot
 
 
+def get_domain_info(da: xr.DataArray, lonlat_dims: list = ["lon", "lat"], re: float = 6378*1.e+03):
+    """
+    Get information about the underlying grid of a DataArray.
+    Assumes a regular, spherical grid (can also be a rotated one if lonlat_dims are adapted accordingly).
+    :param da: The xarray DataArray given on a regular, spherical grid (i.e. providing latitude and longitude coordinates)
+    :param lonlat_dims: Names of the longitude and latitude coordinates
+    :param re: radius of the spherical Earth
+    :return grid_dict: dictionary providing dx (grid spacing), nx (# gridpoints) and Lx (domain length) (same for y) as well as lat0 (central latitude)
+    """
+    # get local logger
+    func_logger = logging.getLogger(f"{logger_module_name}.{get_domain_info.__name__}")
+
+    lon, lat = da[lonlat_dims[0]], da[lonlat_dims[1]]
+
+    try:
+        assert lon.ndim, f"Longitude data must be a 1D-array, but is a {lon.ndim:d}D-array"
+    except AssertionError as e:
+        func_logger.error(e, stack_info=True, exc_info=True)
+        raise e
+    
+    try:
+        assert lat.ndim, f"Latitude data must be a 1D-array, but is a {lat.ndim:d}D-array"
+    except AssertionError as e:
+        func_logger.error(e, stack_info=True, exc_info=True)
+        raise e 
+
+    lat0 = np.mean(lat)
+    nx, ny = len(lon), len(lat)
+    dx, dy = (lon[1] - lon[0]).values, (lat[1] - lat[0]).values 
+
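+    # convert the grid spacing from degrees to metres; the zonal extent is additionally scaled by
+    # cos(lat0) to account for the convergence of the meridians on the sphere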
+    deg2m = re*2*np.pi/360.
+    Lx, Ly = np.abs(nx*dx*deg2m)*np.cos(np.deg2rad(lat0)), np.abs(ny*dy*deg2m)
+
+    grid_dict = {"nx": nx, "ny": ny, "dx": dx, "dy": dy, 
+                 "Lx": Lx, "Ly": Ly, "lat0": lat0}
+    
+    return grid_dict
+
+
+def detrend_data(da: xr.DataArray, xy_dims: list = ["lon", "lat"]):
+    """
+    Detrend data on a limited-area domain to make it periodic in the horizontal directions.
+    Method based on Errico, 1985.
+    :param da: The data given on a regular (spherical) grid
+    :param xy_dims: Names of horizontal dimensions
+    :return: detrended, periodic data
+    """
+    
+    x_dim, y_dim = xy_dims[0], xy_dims[1]
+    nx, ny = len(da[x_dim]), len(da[y_dim])
+    
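+    # following Errico (1985), the linear trend between the first and last grid point is removed in
+    # each direction so that the boundary values match and the field becomes periodic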
+    # remove trend in x-direction
+    fac_x = xr.full_like(da, 1.) * xr.DataArray(2*np.arange(nx) - nx, dims=x_dim, coords={x_dim: da[x_dim]})
+    fac_y = xr.full_like(da, 1.)* xr.DataArray(2*np.arange(ny) - ny, dims=y_dim, coords={y_dim: da[y_dim]})
+    trend_x, _ = xr.broadcast((da.isel({x_dim: -1}) - da.isel({x_dim: 0}))/float(nx-1), da)
+    da = da - 0.5 * trend_x*fac_x
+    # remove trend in y-direction
+    trend_y, _ = xr.broadcast((da.isel({y_dim: -1}) - da.isel({y_dim: 0}))/float(ny-1), da)
+    da = da - 0.5 * trend_y*fac_y
+    
+    return da
+
+
+def angular_integration(da_fft, grid_dict: dict, lcutoff: bool = True):
+    """
+    Get power spectrum as a function of the total wavenumber.
+    The integration in the (k_x, k_y)-plane is done by summation over (k_x, k_y)-pairs lying in annular rings, cf. Durran et al., 2017
+    :param da_fft: Fast Fourier transformed data with (lat, lon) as last two dimensions
+    :param grid_dict: dictionary providing information on underlying grid (generated by get_domain_info-method)
+    :param lcutoff: flag if spectrum should be truncated to cutoff frequency or if full spectrum should be returned (False)
+    :return power spectrum in total wavenumber space.
+    """
+    # get local logger
+    func_logger = logging.getLogger(f"{logger_module_name}.{angular_integration.__name__}")
+
+    sh = da_fft.shape
+    nx, ny = grid_dict["nx"], grid_dict["ny"]
+    dk = np.array([2.*np.pi/grid_dict["Lx"], 2.*np.pi/grid_dict["Ly"]])
+    idkh = np.argmax(dk)
+    dkx, dky = dk
+    nmax = int(np.round(np.sqrt(np.square(nx*dkx) + np.square(ny*dky))/dk[idkh])) 
+
+    sh = (*sh[:-2], nmax) if da_fft.ndim >= 2 else (1, nmax)    # add singleton dim if da_fft is a 2D-array only 
+    spec_radial = np.zeros(sh, dtype="float")
+    
+    # start angular integration
+    func_logger.info(f"Start angular integration for {nmax:d} wavenumbers.")
+    for i in range(nx):
+        for j in range(ny):
+            k_now = int(np.round(np.sqrt(np.square(i*dkx) + np.square(j*dky))/dk[idkh]))
+            spec_radial[..., k_now] += da_fft[..., i, j].real**2 + da_fft[..., i, j].imag**2
+
+    if lcutoff:
+        # Cutting/truncating the spectrum is required to ensure that the combinations of kx and ky
+        # to yield the total wavenumber are complete. Without this truncation, the spectrum gets distorted 
+        # as argued in Errico, 1985 (cf. Eq. 6 therein).
+        # Note that Errico, 1985 chooses the maximum of (nx, ny) since dk is defined as dk=min(kx, ky).
+        # Here, we choose dk = max(dkx, dky) following Durran et al., 2017 (see Eq. 18 therein), and thus have to choose min(nx, ny).
+        cutoff = int(np.round(min(np.array([nx, ny]))/2 + 0.01))
+        spec_radial = spec_radial[..., :cutoff]                
+
+    return np.squeeze(spec_radial)
+
+
+def get_spectrum(da: xr.DataArray, lonlat_dims = ["lon", "lat"], lcutoff: bool = True, re: float = 6378*1e+03):
+    """
+    Compute power spectrum in terms of total wavenumber from numpy-array.
+    Note: Assumes a regular, spherical grid with dx=dy.
+    :param da: DataArray with (lon, lat)-like dimensions
+    :param lcutoff: flag if spectrum should be truncated to cutoff frequency or if full spectrum should be returned (False)
+    :param re: (spherical) Earth radius
+    :return var_rad: power spectrum in terms of wavenumber
+    """
+    # get local logger
+    func_logger = logging.getLogger(f"{logger_module_name}.{get_spectrum.__name__}")
+
+    # sanity check on # dimensions
+    grid_dict = get_domain_info(da, lonlat_dims, re=re)
+    nx, ny = grid_dict["nx"], grid_dict["ny"]
+    lx, ly = grid_dict["Lx"], grid_dict["Ly"]
+
+    da = da.transpose(..., *lonlat_dims)
+    # detrend data to get periodic boundary values (cf. Errico, 1985)
+    da_var = detrend_data(da, xy_dims=lonlat_dims)
+    # ... and apply FFT
+    func_logger.info("Start computing Fourier transformation")
+    fft_var = np.fft.fft2(da_var.values)/float(nx*ny)
+
+    var_rad = angular_integration(fft_var, {"nx": nx, "ny": ny, "Lx": lx, "Ly": ly}, lcutoff)
+
+    return var_rad
+
+
+def sample_permut_xyt(da_orig: xr.DataArray, patch_size: tuple = (6, 6)):
+    """
+    Permutes a sample in a spatio-temporal way following the method of Breiman (2001).
+    The concrete implementation follows Höhlein et al., 2020 with spatial permutation based on patching.
+    Note that the latter allows handling time-invariant data.
+    :param da_orig: original sample; must be 3D with a 'time'-dimension
+    :param patch_size: tuple for patch size
+    :return: spatio-temporally permuted sample 
+    """
+    # get local logger
+    func_logger = logging.getLogger(f"{logger_module_name}.{sample_permut_xyt.__name__}")
+
+    try:
+        assert da_orig.ndim == 3, f"da_orig must be a 3D-array, but has {da_orig.ndim} dimensions."
+    except AssertionError as e:
+        func_logger.error(e, stack_info=True, exc_info=True)
+        raise e
+
+    coords_orig = da_orig.coords
+    dims_orig = da_orig.dims
+    sh_orig = da_orig.shape
+
+    # temporal permutation
+    func_logger.info(f"Start spatio-temporal permutation for sample with shape {sh_orig}.")
+
+    ntimes = len(da_orig["time"])
+    if dims_orig[0] != "time":
+        da_orig = da_orig.transpose("time", ...)
+        coords_now, dims_now, sh_now = da_orig.coords, da_orig.dims, da_orig.shape
+    else:
+        coords_now, dims_now, sh_now = coords_orig, dims_orig, sh_orig
+
+    da_permute = np.random.permutation(da_orig).copy()
+    da_permute = xr.DataArray(da_permute, coords=coords_now, dims=dims_now)
+
+    # spatial permutation with patching
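+    # the field is split into non-overlapping patches of size patch_size which are shuffled as a whole,
+    # i.e. the spatial structure within each patch is preserved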
+    # time must be last dimension (=channel dimension)
+    sh_aux = da_permute.transpose(..., "time").shape
+
+    # Note that the order of x- and y-coordinates does not matter here
+    da_patched = view_as_blocks(da_permute.transpose(..., "time").values, block_shape=(*patch_size, ntimes))
+
+    # convert to DataArray
+    sh = da_patched.shape
+    dims = ["pat_x", "pat_y", "dummy", "ix", "iy", "time"]
+
+    da_patched = xr.DataArray(da_patched, coords={dims[0]: np.arange(sh[0]), dims[1]: np.arange(sh[1]), "dummy": range(1),
+                                                  dims[3]: np.arange(sh[3]), dims[4]: np.arange(sh[4]), "time": da_permute["time"]}, 
+                              dims=dims)
+    
+    # stack xy-patches and permute
+    da_patched = da_patched.stack({"pat_xy": ["pat_x", "pat_y"]})
+    da_patched[...] = np.random.permutation(da_patched.transpose()).transpose()
+    
+    # unstack
+    da_patched = da_patched.unstack().transpose(*dims)
+
+    # revert view_as_blocks-operation
+    da_patched = da_patched.values.transpose([0, 3, 1, 4, 2, 5]).reshape(sh_aux)
+    
+    # write data back on da_permute
+    da_permute[...] = np.moveaxis(da_patched, 2, 0)
+
+    # transpose to original dimension ordering if required
+    da_permute = da_permute.transpose(*dims_orig)
+
+    return da_permute
+
+
+def feature_importance(da: xr.DataArray, predictors: list_or_str, varname_tar: str, model, norm, score_name: str,
+                       data_loader_opt: dict, patch_size = (6, 6), variable_dim = "variable"):
+    """
+    Run feature importance analysis based on the permutation method (see signature of sample_permut_xyt-method)
+    :param da: The (test-)data provided in DataArray with variable-dimension
+    :param predictors: List of predictor variables
+    :param varname_tar: Name of target variable
+    :param model: Trained model for inference
+    :param norm: Normalization object
+    :param score_name: Name of metric-score to be calculated
+    :param data_loader_opt: Dictionary providing options for the TensorFlow data pipeline
+    :param patch_size: Tuple for patch size during spatio-temporal permutation
+    :param variable_dim: Name of variable dimension
+    :return score_all: DataArray with scores for all predictor variables
+    """
+    # get local logger
+    func_logger = logging.getLogger(f"{logger_module_name}.{feature_importance.__name__}")
+
+    # sanity checks
+    _ = check_str_in_list(list(da[variable_dim]), predictors)
+    try:
+        assert da.dims[0] == "time", f"First dimension of the data must be a time-dimensional, but is {da.dims[0]}."
+    except AssertionError as e:
+        func_logger.error(e, stack_info=True, exc_info=True)
+        raise e
+
+    ntimes = len(da["time"])
+
+    # get ground truth data and underlying metadata
+    ground_truth = norm.denormalize(da.sel({variable_dim: varname_tar}), varname=varname_tar)
+
+    # initialize score-array
+    score_all = xr.DataArray(np.zeros((len(predictors), ntimes)), coords={"predictor": predictors, "time": da["time"]},
+                             dims=["predictor", "time"])
+
+    for var in predictors:
+        func_logger.info(f"Run sample importance analysis for {var}...")
+        # get copy of sample array
+        da_copy = da.copy(deep=True)
+        # permute sample
+        da_permut = sample_permut_xyt(da.sel({variable_dim: var}).copy(), patch_size=patch_size)
+        da_copy.loc[{variable_dim: var}] = da_permut
+        
+        # get TF dataset
+        func_logger.info(f"Set-up data pipeline with permuted sample for {var}...")
+        tfds_test = HandleDataClass.make_tf_dataset_allmem(da_copy, **data_loader_opt)
+
+        # predict
+        func_logger.info(f"Run inference with permuted sample for {var}...")
+        y_pred = model.predict(tfds_test, verbose=2)
+
+        # convert to xarray
+        y_pred = convert_to_xarray(y_pred, norm, varname_tar, ground_truth.coords, ground_truth.dims, True)
+
+        # calculate score
+        func_logger.info(f"Calculate score for permuted samples of {var}...")
+        score_engine = Scores(y_pred, ground_truth, dims=ground_truth.dims[1::])
+        score_all.loc[{"predictor": var}] = score_engine(score_name)
+
+        #free_mem([da_copy, da_permut, tfds_test, y_pred, score_engine])
+
+    return score_all
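+
+# Illustrative interpretation sketch (comment only; score_base is an assumption):
+# the importance of a predictor is usually judged relative to the score obtained with unpermuted data, e.g.
+#   rel_change = score_all.sel({"predictor": var}).mean("time") / score_base
+# For error-type scores (e.g. RMSE), values clearly above 1 indicate that the model relies on this predictor.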
+
+
 class Scores:
     """
     Class to calculate scores and skill scores.
@@ -224,16 +554,121 @@ class Scores:
     def avg_dims(self, dims):
         if dims is None:
             self.avg_dims = self.data_dims
-            print("Scores will be averaged across all data dimensions.")
+            # print("Scores will be averaged across all data dimensions.")
         else:
             dim_stat = [avg_dim in self.data_dims for avg_dim in dims]
             if not all(dim_stat):
-                ind_bad = [i for i, x in enumerate(dim_stat) if x]
-                raise ValueError("The following dimensions for score-averaging are not" +
-                                 "part of the data: {0}".format(", ".join(dims[ind_bad])))
+                ind_bad = [i for i, x in enumerate(dim_stat) if not x]
+                raise ValueError("The following dimensions for score-averaging are not " +
+                                 "part of the data: {0}".format(", ".join(np.array(dims)[ind_bad])))
 
             self._avg_dims = dims
 
+    def get_2x2_event_counts(self, thresh):
+        """
+        Get the counts of the 2x2 contingency table.
+        :param thresh: threshold to define events
+        :return: (a, b, c, d)-tuple of the 2x2 contingency table
+        """
+        a = ((self.data_fcst >= thresh) & (self.data_ref >= thresh)).sum(dim=self.avg_dims)
+        b = ((self.data_fcst >= thresh) & (self.data_ref < thresh)).sum(dim=self.avg_dims)
+        c = ((self.data_fcst < thresh) & (self.data_ref >= thresh)).sum(dim=self.avg_dims)
+        d = ((self.data_fcst < thresh) & (self.data_ref < thresh)).sum(dim=self.avg_dims)
+
+        return a, b, c, d
+
+    def calc_ets(self, thresh=0.1):
+        """
+        Calculates Equitable Threat Score (ETS) on data.
+        :param thresh: threshold to define events
+        :return: ets-values
+        """
+        a, b, c, d = self.get_2x2_event_counts(thresh)
+        n = a + b + c + d
+        ar = (a + b)*(a + c)/n      # random reference forecast
+        
+        denom = (a + b + c - ar)
+
+        ets = (a - ar)/denom
+        ets = ets.where(denom > 0, np.nan)
+
+        return ets
+    
+    def calc_fbi(self, thresh=0.1):
+        """
+        Calculates Frequency bias (FBI) on data.
+        :param thresh: threshold to define events
+        :return: fbi-values
+        """
+        a, b, c, d = self.get_2x2_event_counts(thresh)
+
+        denom = a+c
+        fbi = (a + b)/denom
+
+        fbi = fbi.where(denom > 0, np.nan)
+
+        return fbi
+    
+    def calc_pss(self, thresh=0.1):
+        """
+        Calculates Peirce Skill Score (PSS) on data.
+        :param thresh: threshold to define events
+        :return: pss-values
+        """
+        a, b, c, d = self.get_2x2_event_counts(thresh)      
+
+        denom = (a + c)*(b + d)
+        pss = (a*d - b*c)/denom
+
+        pss = pss.where(denom > 0, np.nan)
+
+        return pss   
+
+    def calc_l1(self, **kwargs):
+        """
+        Calculate the L1 error norm of forecast data w.r.t. reference data.
+        The sum of absolute errors is divided by the number of samples along the averaging dimensions,
+        i.e. it is similar to the MAE, but reduces the data to a single number.
+        :return: L1-error
+        """
+        if kwargs:
+            print("Passed keyword arguments to calc_l1 are without effect.")   
+
+        l1 = np.sum(np.abs(self.data_fcst - self.data_ref))
+
+        len_dims = np.array([self.data_fcst.sizes[dim] for dim in self.avg_dims])
+        l1 /= np.prod(len_dims)
+
+        return l1
+    
+    def calc_l2(self, **kwargs):
+        """
+        Calculate the squared L2 error norm of forecast data w.r.t. reference data.
+        The sum of squared errors is divided by the number of samples along the averaging dimensions,
+        i.e. it is similar to the MSE, but reduces the data to a single number.
+        :return: (squared) L2-error
+        """
+        if kwargs:
+            print("Passed keyword arguments to calc_l2 are without effect.")   
+
+        l2 = np.sum(np.square(self.data_fcst - self.data_ref))
+
+        len_dims = np.array([self.data_fcst.sizes[dim] for dim in self.avg_dims])
+        l2 /= np.prod(len_dims)
+
+        return l2
+    
+    def calc_mae(self, **kwargs):
+        """
+        Calculate mean absolute error (MAE) of forecast data w.r.t. reference data
+        :return: MAE averaged over provided dimensions
+        """
+        if kwargs:
+            print("Passed keyword arguments to calc_mae are without effect.")   
+
+        mae = np.abs(self.data_fcst - self.data_ref).mean(dim=self.avg_dims)
+
+        return mae
+
     def calc_mse(self, **kwargs):
         """
         Calculate mse of forecast data w.r.t. reference data
@@ -251,6 +686,25 @@ class Scores:
         rmse = np.sqrt(self.calc_mse(**kwargs))
 
         return rmse
+    
+    def calc_acc(self, clim_mean: xr.DataArray, spatial_dims: List = ["lat", "lon"]):
+        """
+        Calculate anomaly correlation coefficient (ACC).
+        :param clim_mean: climatological mean of the data
+        :param spatial_dims: names of spatial dimensions over which the ACC is calculated.
+                             Note: No averaging is possible over these dimensions.
+        :return acc: Averaged ACC (except over spatial_dims)
+        """
+
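+        # ACC = sum(f'*o') / sqrt(sum(f'^2) * sum(o'^2)) over the spatial dimensions,
+        # where f' and o' denote the forecast and reference anomalies w.r.t. clim_mean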
+        fcst_ano, obs_ano = self.data_fcst - clim_mean, self.data_ref - clim_mean
+
+        acc = (fcst_ano*obs_ano).sum(spatial_dims)/np.sqrt((fcst_ano**2).sum(spatial_dims)*(obs_ano**2).sum(spatial_dims))
+
+        mean_dims = [x for x in self.avg_dims if x not in spatial_dims]
+        if len(mean_dims) > 0:
+            acc = acc.mean(mean_dims)
+
+        return acc
 
     def calc_bias(self, **kwargs):
 
@@ -297,6 +751,131 @@ class Scores:
             ratio_spat_variability = ratio_spat_variability.mean(dim=avg_dims)
 
         return ratio_spat_variability
+    
+    def calc_iqd(self, xnodes=None, nh=0, lfilter_zero=True):
+        """
+        Calculates the integrated squared distance between the CDFs of forecast and reference data.
+        Method: Retrieves the empirical CDF, calculates CDF(xnodes) for both data sets and
+                then uses the squared differences at xnodes for trapezoidal integration.
+                Note that xnodes should be selected such that the CDF increases
+                more or less continuously by ~0.01 - 0.05 between consecutive xnodes-elements
+                to ensure accurate integration.
+        The forecast and reference data are taken from the Scores-object (data_fcst and data_ref).
+
+        :param xnodes    : x-values used as nodes for integration (optional, automatically set if not given
+                           according to stochastic properties of precipitation data)
+        :param nh        : accumulation period (affects setting of xnodes)
+        :param lfilter_zero: Flag to filter out zero values from the CDF calculation
+        :return: integrated squared distance between the CDFs of forecast and reference data
+        """
+
+        data_simu = self.data_fcst.values.flatten()
+        data_obs = self.data_ref.values.flatten() 
+
+        # Scarlet: because data_simu and data_obs are flattened anyway this is not needed
+        #if np.ndim(data_simu) != 1 or np.ndim(data_obs) != 1:
+        #    raise ValueError("Input data arrays must be 1D-arrays.")
+
+        if xnodes is None:
+            if nh == 1:
+                xnodes = [0., 0.005, 0.01, 0.015, 0.025, 0.04, 0.06, 0.08, 0.1, 0.13, 0.16, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6,
+                          0.8, 1., 1.25, 1.5, 1.8, 2.4, 3., 3.75, 4.5, 5.25, 6., 7., 9., 12., 20., 30., 50.]
+            elif 1 < nh <= 6:
+                ### obtained manually based on observational data between May and July 2017
+                ### except for the first step and the highest node-values,
+                ### CDF is increased by 0.03 - 0.05 with every step ensuring accurate integration
+                xnodes = [0.00, 0.005, 0.01, 0.02, 0.04, 0.07, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1.15, 1.5, 1.9, 2.4,
+                         3., 4., 5., 6., 7.5, 10., 15., 25., 40., 60.]
+            else:
+                xnodes = [0.00, 0.01, 0.02, 0.035, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1., 1.3, 1.7, 2.1, 2.5,
+                          3., 3.5, 4., 4.75, 5.5, 6.3, 7.1, 8., 9., 10., 12.5, 15., 20., 25., 35., 50., 70., 100.]
+
+        data_simu_filt = data_simu[~np.isnan(data_simu)]
+        data_obs_filt = data_obs[~np.isnan(data_obs)]
+        if lfilter_zero:
+            data_simu_filt = np.sort(data_simu_filt[data_simu_filt > 0.])
+            data_obs_filt  = np.sort(data_obs_filt[data_obs_filt > 0.])
+        else:
+            data_simu_filt = np.sort(data_simu_filt)
+            data_obs_filt  = np.sort(data_obs_filt)
+
+        nd_points_simu = np.shape(data_simu_filt)[0]
+        nd_points_obs  = np.shape(data_obs_filt)[0]
+
+        prob_simu = 1. * np.arange(nd_points_simu)/ (nd_points_simu - 1)
+        prob_obs  = 1. * np.arange(nd_points_obs)/ (nd_points_obs -1)
+
+        cdf_simu = get_cdf_of_x(data_simu_filt,prob_simu)
+        cdf_obs  = get_cdf_of_x(data_obs_filt,prob_obs)
+
+        yvals_simu = cdf_simu(xnodes)
+        yvals_obs  = cdf_obs(xnodes)
+
+        if yvals_obs[-1] < 0.999:
+            print("CDF of last xnodes {0:5.2f} for observation data is smaller than 99.9%." +
+                  "Consider setting xnodes manually!")
+
+        if yvals_simu[-1] < 0.999:
+            print("CDF of last xnodes {0:5.2f} for simulation data is smaller than 99.9%." +
+                  "Consider setting xnodes manually!")
+
+        # finally, perform trapezoidal integration
+        return np.trapz(np.square(yvals_obs - yvals_simu), xnodes)
+
+    def calc_seeps(self, seeps_weights: xr.DataArray, t1: xr.DataArray, t3: xr.DataArray, spatial_dims: List):
+        """
+        Calculates the stable equitable error in probability space (SEEPS), see Rodwell et al., 2011
+        :param seeps_weights: SEEPS-parameter matrix to weight contingency table elements
+        :param t1: threshold for light precipitation events
+        :param t3: threshold for strong precipitation events
+        :param spatial_dims: list/name of spatial dimensions of the data
+        :return seeps skill score (i.e. 1-SEEPS)
+        """
+
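+        # The thresholds define three precipitation categories: 'dry' (<= t1), 'light' (between t1 and t3)
+        # and 'heavy' (>= t3); the inner seeps-function maps each pair of forecast/observed categories
+        # onto the corresponding element of the 3x3 SEEPS weight matrix.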
+        def seeps(data_ref, data_fcst, thr_light, thr_heavy, seeps_weights):
+            ob_ind = (data_ref > thr_light).astype(int) + (data_ref >= thr_heavy).astype(int)
+            fc_ind = (data_fcst > thr_light).astype(int) + (data_fcst >= thr_heavy).astype(int)
+            indices = fc_ind * 3 + ob_ind  # index of each data point in their local 3x3 matrices
+            seeps_val = seeps_weights[indices, np.arange(len(indices))]  # pick the right weight for each data point
+            
+            return 1.-seeps_val
+        
+        if self.data_fcst.ndim == 3:
+            assert len(spatial_dims) == 2, f"Provide two spatial dimensions for three-dimensional data."
+            data_fcst, data_ref = self.data_fcst.stack({"xy": spatial_dims}), self.data_ref.stack({"xy": spatial_dims})
+            seeps_weights = seeps_weights.stack({"xy": spatial_dims})
+            t3 = t3.stack({"xy": spatial_dims})
+            lstack = True
+        elif self.data_fcst.ndim == 2:
+            data_fcst, data_ref = self.data_fcst, self.data_ref
+            lstack = False
+        else:
+            raise ValueError(f"Data must be a two-or-three-dimensional array.")
+
+        # check dimensioning of data
+        assert data_fcst.ndim <= 2, f"Data must be one- or two-dimensional, but has {data_fcst.ndim} dimensions. Check if stacking with spatial_dims may help." 
+
+        if data_fcst.ndim == 1:
+            seeps_values_all = seeps(data_ref, data_fcst, t1.values, t3, seeps_weights)
+        else:
+            data_fcst, data_ref = data_fcst.transpose(..., "xy"), data_ref.transpose(..., "xy")
+            seeps_values_all = xr.full_like(data_fcst, np.nan)
+            seeps_values_all.name = "seeps"
+            for it in range(data_ref.shape[0]):
+                data_fcst_now, data_ref_now = data_fcst[it, ...], data_ref[it, ...]
+                # in case of missing data, skip computation
+                if np.all(np.isnan(data_fcst_now)) or np.all(np.isnan(data_ref_now)):
+                    continue
+
+                seeps_values_all[it,...] = seeps(data_ref_now, data_fcst_now, t1.values, t3, seeps_weights.values)
+
+        if lstack:
+            seeps_values_all = seeps_values_all.unstack()
+
+        seeps_values = seeps_values_all.mean(dim=self.avg_dims)
+
+        return seeps_values
 
     @staticmethod
     def calc_geo_spatial_diff(scalar_field: xr.DataArray, order: int = 1, r_e: float = 6371.e3, dom_avg: bool = True):
diff --git a/downscaling_ap5/preprocess/add_invariant_data_atmorep.ipynb b/downscaling_ap5/preprocess/add_invariant_data_atmorep.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d8f6d8ad0f2cf6906c612b586556115317dbae15
--- /dev/null
+++ b/downscaling_ap5/preprocess/add_invariant_data_atmorep.ipynb
@@ -0,0 +1,188 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "ceeb7395-688e-43ee-9aa4-55818a592555",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "# Notebook to add constant variables to competing AtmoRep downscaling data\n",
+    "\n",
+    "This Notebook processes the files generated with `preprocees_data_atmorep.sh` to add the surface topography from ERA5 and COSMO REA6 data which both constitute invariant fields, but have to be expanded to include a time-dimension."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7426cf19-7ff9-4358-8615-57164e7c7f9e",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "! pip install findlibs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a2b2f8f1-7112-42ac-bf48-8018eb7e4b1d",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "import glob\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import xarray as xr\n",
+    "import cfgrib"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8b6adba5-71f6-4ab3-b866-c76fb02cedeb",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "Parameters:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c2220e62-6873-4fb4-86e9-2c40a3b519ac",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "data_dir=\"/p/scratch/deepacf/maelstrom/maelstrom_data/ap5/competing_atmorep/\"\n",
+    "invar_file_era5 = \"/p/scratch/deepacf/maelstrom/maelstrom_data/ap5/competing_atmorep/reanalysis_orography.nc\"\n",
+    "invar_file_crea6 = \"/p/scratch/atmo-rep/data/cosmo_rea6/static/cosmo_rea6_orography.nc\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "afb9b245-9ecf-4c91-8103-de4811240e0e",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "The file 'invar_file_era5' has been generated with the following CDO-command:\n",
+    "``` \n",
+    "cdo --reduce_dim -t ecmwf -f nc copy -remapbil,~/downscaling_maelstrom/downscaling_jsc_repo/downscaling_ap5/grid_des/crea6_reg_grid reanalysis_orography.grib reanalysis_orography.nc\n",
+    "``` \n",
+    "where the original grib-file was obatined from AtmoRep (```/p/scratch/atmo-rep/data/era5/static```)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cacd266c-6e0e-476d-9805-f4fac289e3cf",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "file_list = glob.glob(os.path.join(data_dir, \"downscaling_atmorep*.nc\"))\n",
+    "\n",
+    "if len(file_list) == 0:\n",
+    "    raise FileNotFoundError(f\"Could not find any datafiles under '{data_dir}'...\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2c9bb25b-4930-47db-8f9d-a6e1cbc59571",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "ds_invar_era5 = xr.open_dataset(invar_file_era5)\n",
+    "ds_invar_crea6 = xr.open_dataset(invar_file_crea6).sel({\"lat\": ds_invar_era5[\"lat\"], \"lon\": ds_invar_era5[\"lon\"]})\n",
+    "ds_invar_crea6 = ds_invar_crea6.drop_vars(\"FR_LAND\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3ebb63ef-9dd0-49e1-be15-560f87c51166",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "for f in tqdm(file_list):\n",
+    "    # read current file\n",
+    "    print(f\"Process data-file '{f}'...\")\n",
+    "    ds_now = xr.open_dataset(f)\n",
+    "    var_list = list(ds_now.data_vars)\n",
+    "    lchange = False\n",
+    "    \n",
+    "    if \"z_in\" not in var_list:\n",
+    "        print(f\"Add surface topography from ERA5...\")\n",
+    "        dst = ds_invar_era5.expand_dims(time=ds_now[\"time\"])\n",
+    "        dst = dst.rename({\"Z\": \"z_in\"})\n",
+    "    \n",
+    "        ds_all = xr.merge([ds_now, dst])\n",
+    "        lchange = True\n",
+    "        \n",
+    "    if \"hsurf_tar\" not in var_list:\n",
+    "        print(f\"Add surface topography from CREA6...\")\n",
+    "        dst = ds_invar_crea6.expand_dims(time=ds_now[\"time\"])\n",
+    "        dst = dst.rename({\"z\": \"hsurf_tar\"})\n",
+    "    \n",
+    "        ds_all = xr.merge([ds_all , dst])\n",
+    "        lchange = True\n",
+    "        \n",
+    "    if \"t2m_ml0_tar\" in var_list:\n",
+    "        ds_all = ds_all.rename({\"t2m_ml0_tar\": \"t2m_tar\"})\n",
+    "        lchange = True\n",
+    "    \n",
+    "    if lchange:\n",
+    "        print(f\"Write modified dataset back to '{f}'...\")\n",
+    "        ds_all.to_netcdf(f.replace(\".nc\", \"_new.nc\"))\n",
+    "    else:\n",
+    "        print(f\"No changes to data from '{f}' applied. Continue...\")\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "91ceb153-5900-428f-a097-fd9aab19c669",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "langguth1_downscaling_kernel",
+   "language": "python",
+   "name": "langguth1_downscaling_kernel"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/downscaling_ap5/preprocess/download_era5_data.py b/downscaling_ap5/preprocess/download_era5_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..caffe444b5b2ba30d3c1397a2f7cfdc5ba441dbd
--- /dev/null
+++ b/downscaling_ap5/preprocess/download_era5_data.py
@@ -0,0 +1,288 @@
+# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
+#
+# SPDX-License-Identifier: MIT
+
+"""
+Script to download ERA5 data from the CDS API.
+"""
+
+__author__ = "Michael Langguth"
+__email__ = "m.langguth@fz-juelich.de"
+__date__ = "2023-11-21"
+__update__ = "2023-08-22"
+
+# import modules
+import os, sys
+import logging
+import cdsapi
+import numpy as np
+import pandas as pd
+from multiprocessing import Pool
+from other_utils import to_list
+
+# get logger
+logger_module_name = f"main_download_era5.{__name__}"
+module_logger = logging.getLogger(logger_module_name)
+
+# known request parameters
+known_req_keys = ["ml", "sfc"]
+
+
+class ERA5_Data_Loader(object):
+    """
+    Class to download ERA5 data from the CDS API.
+    """
+
+    known_req_keys = ["ml", "sfc"]
+    allowed_formats = ["netcdf", "grib"]
+    # defaults 
+    area = [75, -45, 20, 65]
+    month_start = 1
+    month_end = 12
+
+
+    def __init__(self, nworkers) -> None:
+
+        # get local logger
+        func_logger = logging.getLogger(f"{logger_module_name}.{self.__init__.__name__}")
+        
+        self.nworkers = nworkers
+        try: 
+            self.cds = cdsapi.Client()
+        except Exception as e:
+            func_logger.error(f"Could not initialize CDS API client: {e} \n" + \
+                              "Please follow the instructions at https://cds.climate.copernicus.eu/api-how-to to install the CDS API.")                                                 
+            raise e
+        
+    def __call__(self, req_dict, data_dir, start, end, format, **kwargs):
+        """
+        Run the requests to download the ERA5 data.
+        :param req_dict: dictionary with data requests
+        :param data_dir: directory where output data files will be stored
+        :param start: start year of data request
+        :param end: end year of data request
+        :param format: format of downloaded data (netcdf or grib)
+        :param kwargs: additional keyword arguments (options: area, month_start, month_end)
+        """
+
+        # get local logger
+        func_logger = logging.getLogger(f"{logger_module_name}.{self.__call__.__name__}")
+
+        # create output directory
+        if not os.path.exists(data_dir):
+            func_logger.info(f"Creating output directory {data_dir}")
+            os.makedirs(data_dir)
+
+        # validate request keys
+        req_list = self.validate_request_keys(req_dict)
+
+        # select data format
+        if format not in self.allowed_formats:
+            func_logger.warning(f"Unknown data format {format}. Using default format netcdf.")
+            format = "netcdf"
+
+        for req_key in req_list:
+            if req_key == "sfc":
+                out = self.download_sfc_data(req_dict[req_key], data_dir, start, end, format, **kwargs)
+            elif req_key == "ml":
+                out = self.download_ml_data(req_dict[req_key], data_dir, start, end, format, **kwargs)
+
+            # check if all requests were successful
+            _ = self.check_request_result(out)
+
+        return out
+    
+
+    def download_sfc_data(self, varlist, data_dir, start, end, format, **kwargs):
+        """
+        Download ERA5 surface data.
+        To obtain the varlist, please refer to the CDS API documentation at https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-single-levels?tab=form
+        :param varlist: list of variables to download
+        :param data_dir: directory where output data files will be stored
+        :param start: start year of data request
+        :param end: end year of data request
+        :param format: format of downloaded data (netcdf or grib)
+        :param kwargs: additional keyword arguments (options: area, month_start, month_end)
+        :return: output of multiprocessing pool
+        """
+        # get local logger
+        func_logger = logging.getLogger(f"{logger_module_name}.{self.download_sfc_data.__name__}")
+
+        # get additional keyword arguments
+        area = kwargs.get("area", self.area)
+        month_start = kwargs.get("month_start", self.month_start)
+        month_end = kwargs.get("month_end", self.month_end)
+
+        # create base request dictionary (None-values will be set dynamically)
+        req_dict_base = {"product_type": "reanalysis", "format": f"{format}", 
+                         "variable": to_list(varlist), 
+                         "day": None, "month": None, "time": [f"{h:02d}" for h in range(24)], "year": None,
+                        "area": area}   
+
+        # initialize multiprocessing pool
+        func_logger.info(f"Downloading ERA5 surface data for variables {', '.join(varlist)} with {self.nworkers} workers.")
+        pool = Pool(self.nworkers)
+
+        # initialize dictionary for request results
+        req_results = {}
+
+        # create data requests for each month
+        for year in range(start, end+1):
+            req_dict = req_dict_base.copy()
+            req_dict["year"] = [f"{year}"]
+            for month in range(month_start, month_end+1):
+                req_dict["month"] = [f"{month:02d}"]
+                # get last day of month
+                last_day = pd.Timestamp(year, month, 1) + pd.offsets.MonthEnd(1)
+                req_dict["day"] = [f"{d:02d}" for d in range(1, last_day.day+1)]
+                fout = f"era5_sfc_{year}-{month:02d}.{format}"
+                
+                func_logger.debug(f"Downloading ERA5 surface data for {year}-{month:02d} to {os.path.join(data_dir, fout)}")
+
+                req_results[fout] = pool.apply_async(self.cds.retrieve, args=("reanalysis-era5-single-levels", req_dict,
+                                                     os.path.join(data_dir, fout)))
+
+        # run and close multiprocessing pool
+        pool.close()
+        pool.join()
+
+        func_logger.info(f"Finished downloading ERA5 surface data.")
+
+        return req_results
+
+    def download_ml_data(self, var_dict, data_dir, start, end, format, **kwargs):
+        """
+        Download ERA5 data for multiple levels.
+        To obtain the varlist, please refer to the CDS API documentation at https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-complete?tab=form
+        :param var_dict: dictionary of variables to download for each level
+        :param data_dir: directory where output data files will be stored
+        :param start: start year of data request
+        :param end: end year of data request
+        :param format: format of downloaded data (netcdf or grib)
+        :param kwargs: additional keyword arguments (options: area, month_start, month_end)
+        :return: output of multiprocessing pool
+        """
+        # get local logger
+        func_logger = logging.getLogger(f"{logger_module_name}.{self.download_ml_data.__name__}")
+
+        # get additional keyword arguments
+        area = kwargs.get("area", self.area)
+        month_start = kwargs.get("month_start", self.month_start)
+        month_end = kwargs.get("month_end", self.month_end)
+
+        vars = var_dict.keys()
+        vars_param = self.translate_mars_vars(vars)
+        # ensure that data is downloaded for all levels (All-together approach -> overhead: additional download of data for levels that are not needed)
+        collector = []
+        _ = [collector.extend(ml_list) for ml_list in var_dict.values()]
+        all_lvls = sorted([str(lvl) for lvl in set(collector)])
+
+        # create base request dictionary (None-values will be set dynamically)
+        req_dict_base = {"class": "ea", "date": None,
+                         "expver": "1",
+                         "levelist": "/".join(all_lvls),
+                         "levtype": "ml",
+                         "param": vars_param,
+                         "stream": "oper",
+                         "time": "00/to/23/by/1",
+                         "type": "an", 
+                         "area": area ,
+                         "grid": "0.25/0.25",}
+
+        # initialize multiprocessing pool
+        func_logger.info(f"Downloading ERA5 model-level data for variables {', '.join(vars)} with {self.nworkers} workers.")
+        pool = Pool(self.nworkers)
+        
+        # initialize dictionary for request results
+        req_results = {}
+
+        # create data requests for each month
+        for year in range(start, end+1):
+            req_dict = req_dict_base.copy()
+            for month in range(month_start, month_end+1):
+                # get last day of month
+                last_day = pd.Timestamp(year, month, 1) + pd.offsets.MonthEnd(1)
+                req_dict["date"] = f"{year}-{month:02d}-01/to/{year}-{month:02d}-{last_day.day}" 
+                fout = f"era5_ml_{year}-{month:02d}.{format}"
+
+                func_logger.debug(f"Downloading ERA5 model-level data for {year}-{month:02d} to {os.path.join(data_dir, fout)}")
+
+                req_results[fout] = pool.apply_async(self.cds.retrieve, args=("reanalysis-era5-complete", req_dict,
+                                                     os.path.join(data_dir, fout)))
+            
+        # run and close multiprocessing pool
+        pool.close()
+        pool.join()
+
+        func_logger.info(f"Finished downloading ERA5 model-level data.")
+
+        return req_results
+    
+    def check_request_result(self, results_dict):
+        """
+        Check if request was successful.
+        :param results_dict: dictionary with request results (returned by download_ml_data and download_sfc_data)
+        """
+        # get local logger
+        func_logger = logging.getLogger(f"{logger_module_name}.{self.check_request_result.__name__}")
+
+        # check if all requests were successful
+        stat = [o.get().__dict__["reply"]["state"] == "completed" for o in results_dict.values()]
+
+        ok = True
+
+        if all(stat):
+            func_logger.info(f"All requests were successful.")
+        else:
+            ok = False
+            bad_req = np.where(np.array(stat) == False)[0]
+            results_arr = np.array(list(results_dict.keys()))
+            func_logger.error(f"The following requests were not successful: {', '.join(results_arr[bad_req])}.")
+
+        return ok
+
+
+    def validate_request_keys(self, req_dict):
+        """
+        Validate request keys in data request file.
+        :param req_dict: dictionary with data requests
+        :return: list of valid request keys (filtered)
+        """
+        # get local logger
+        func_logger = logging.getLogger(f"{logger_module_name}.{self.validate_request_keys.__name__}")
+
+        # create list of requests
+        req_list = []
+        for req_key in req_dict.keys():
+            if req_key in self.known_req_keys:
+                req_list.append(req_key)
+            else:
+                func_logger.warning(f"Unknown request key {req_key} in data request file. Skipping this request.")
+
+        return req_list
+
+    def translate_mars_vars(self, vars):
+        """
+        Translate variable names to MARS parameter names.
+        :param vars: list of variable names
+        :return: list of MARS parameter names
+        """
+
+        # create dictionary with variable names and corresponding MARS parameter names
+        var_dict = {"z": "129", "t": "130", "u": "131", "v": "132", "w": "135", "q": "133", 
+                    "temperature": "130", "specific_humidity": "133", "geopotential": "129",
+                    "vertical_velocity": "135", "u_component_of_wind": "131", "v_component_of_wind": "132",
+                    "vorticity": "138", "divergence": "139", "logarithm_of_surface_pressure": "152", 
+                    "fraction_of_cloud_cover": "164", "specific_cloud_liquid_water_content": "246",
+                    "specific_cloud_ice_water_content": "247", "specific_rain_water_content": "248",
+                    "specific_snow_water_content": "249", }
+
+        # translate variable names
+        vars_param = [var_dict[var] for var in vars]
+
+        return vars_param
diff --git a/downscaling_ap5/preprocess/preprocess_data_atmorep.sh b/downscaling_ap5/preprocess/preprocess_data_atmorep.sh
new file mode 100755
index 0000000000000000000000000000000000000000..1b406a6d8b8f75b594970b76d1e347015c58e2d3
--- /dev/null
+++ b/downscaling_ap5/preprocess/preprocess_data_atmorep.sh
@@ -0,0 +1,101 @@
+#!/bin/bash
+
+##############################################################################
+# Script to process the ERA5 and COSMO REA6 data that has also been used     #
+# in the AtMoRep project. To be processable for the data loader in MAELSTROM,#
+# the input and target data must be provided in monthly netCDF files.        #
+# Furthermore, input and target variables must be denoted with _in and _tar, #
+# respectively.                                                              #
+##############################################################################
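+#
+# Illustrative result (a sketch based on the settings below): each month yields a
+# single file such as downscaling_atmorep_train_1995-01.nc containing ERA5
+# variables like t_ml137_in and COSMO-REA6 variables like t2m_ml0_tar.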
+
+# author: Michael Langguth
+# date: 2023-08-16
+# update: 2023-08-16
+
+# parameters
+
+era5_basedir=/p/scratch/atmo-rep/data/era5/ml_levels/
+crea6_basedir=/p/scratch/atmo-rep/data/cosmo_rea6/ml_levels/
+output_dir=/p/scratch/deepacf/maelstrom/maelstrom_data/ap5/competing_atmorep/
+crea6_gdes=../grid_des/crea6_reg_grid
+
+era5_vars=("t")
+era5_vars_full=("temperature")
+crea6_vars=("t2m")
+crea6_vars_full=("t_2m")
+
+ml_lvl_era5=( 96 105 114 123 137 )
+ml_lvl_crea6=( 0 )
+
+year_start=1995
+year_end=2018
+
+# main
+
+echo "Loading required modules..."
+ml purge
+ml Stages/2022 GCC/11.2.0  OpenMPI/4.1.2 NCO/5.0.3 CDO/2.0.2
+
+tmp_dir=${output_dir}/tmp
+
+# create output directory
+if [ ! -d $output_dir ]; then
+    mkdir -p $output_dir
+fi
+
+if [ ! -d $tmp_dir ]; then
+    mkdir -p $tmp_dir
+fi
+
+# loop over years and months
+for yr in $(eval echo "{$year_start..$year_end}"); do 
+    for mm in {01..12}; do
+        echo "Processing ERA5 data for ${yr}-${mm}..."
+        for ivar in "${!era5_vars[@]}"; do
+	    echo "Processing variable ${era5_vars[ivar]}  from ERA5..."	
+            for ml_lvl in ${ml_lvl_era5[@]}; do 
+                echo "Processing data for level ${ml_lvl}"
+		# get file name
+                era5_file=${era5_basedir}/${ml_lvl}/"${era5_vars_full[ivar]}"/reanalysis_"${era5_vars_full[ivar]}"_y${yr}_m${mm}_ml${ml_lvl}.grib
+                tmp_file1=${tmp_dir}/era5_${era5_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}_tmp1.nc 
+                tmp_file2=${tmp_dir}/era5_${era5_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}_tmp2.nc
+                tmp_era5=${tmp_dir}/era5_${era5_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}_tmp3.nc
+                # convert to netCDF and slice to region of interest
+                cdo -f nc copy -sellonlatbox,-1.5,25.75,42.25,56 ${era5_file} ${tmp_file1}
+                cdo remapbil,${crea6_gdes} ${tmp_file1} ${tmp_file2}
+                # rename variable
+                cdo --reduce_dim chname,${era5_vars[ivar]},${era5_vars[ivar]}_ml${ml_lvl}_in -selvar,${era5_vars[ivar]} ${tmp_file2} ${tmp_era5} 
+                # clean-up
+                rm ${tmp_file1} ${tmp_file2}
+            done
+        done
+
+	era5_file_now=${tmp_dir}/era5_y${yr}_m${mm}.nc
+        cdo merge ${tmp_dir}/*.nc ${era5_file_now}
+	# clean-up
+	rm ${tmp_dir}/era5*tmp*.nc
+
+        echo "Processing COSMO-REA6 data for ${yr}-${mm}..."
+        for ivar in "${!crea6_vars[@]}"; do
+            for ml_lvl in "${ml_lvl_crea6[@]}"; do
+                crea6_file=${crea6_basedir}/${ml_lvl}/${crea6_vars[ivar]}/cosmo_rea6_${crea6_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}.nc
+                tmp_crea6=${tmp_dir}/crea6_${crea6_vars[ivar]}_y${yr}_m${mm}_ml${ml_lvl}_tmp_crea6.nc
+                cdo --reduce_dim -chname,${crea6_vars_full[ivar]},${crea6_vars[ivar]}_ml${ml_lvl}_tar -sellonlatbox,-1.25,25.6875,42.3125,55.75 -selvar,${crea6_vars_full[ivar]} ${crea6_file} ${tmp_crea6}
+            done
+        done
+
+	crea6_file_now=${tmp_dir}/crea6_y${yr}_m${mm}.nc
+        cdo merge ${tmp_dir}/*tmp_crea6.nc ${crea6_file_now}
+
+        cdo merge ${crea6_file_now} ${era5_file_now} ${output_dir}/downscaling_atmorep_train_${yr}-${mm}.nc;
+	
+	# clean-up
+	rm ${tmp_dir}/*.nc
+    done
+done
diff --git a/downscaling_ap5/preprocess/preprocess_data_era5_to_crea6.py b/downscaling_ap5/preprocess/preprocess_data_era5_to_crea6.py
old mode 100644
new mode 100755
diff --git a/downscaling_ap5/utils/other_utils.py b/downscaling_ap5/utils/other_utils.py
index 2853df1e88789774b1821e482a81df47f9443ab5..649f2de8db122d8c95d4f1ab6888aa71a12a1c22 100644
--- a/downscaling_ap5/utils/other_utils.py
+++ b/downscaling_ap5/utils/other_utils.py
@@ -12,24 +12,27 @@ Some auxiliary functions for the project:
     * subset_files_on_date
     * extract_date
     * ensure_datetime
+    * doy_to_mo
     * last_day_of_month
     * flatten
     * remove_files
     * check_str_in_list
+    * shape_from_str
     * find_closest_divisor
-    * free_mem
+#    * free_mem
     * print_gpu_usage
     * print_cpu_usage
     * get_memory_usage
     * get_max_memory_usage
     * copy_filelist
+    * merge_dicts
 """
 # doc-string
 
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-01-20"
-__update__ = "2023-03-17"
+__update__ = "2023-11-27"
 
 import os
 import gc
@@ -170,6 +173,18 @@ def ensure_datetime(date):
     return date_dt
 
 
+def doy_to_mo(day_of_year: int, year: int):
+    """
+    Converts day of year to year-month datetime object (e.g. in_day = 2, year = 2017 yields January 2017).
+    From AtmoRep-project: https://isggit.cs.uni-magdeburg.de/atmorep/atmorep/
+    :param day_of_year: day of year
+    :param year: corresponding year (to reflect leap years)
+    :return year-month datetime object
+    """
+    date_month = pd.to_datetime(year * 1000 + day_of_year, format='%Y%j')
+    return date_month
+
+
 def last_day_of_month(any_day):
     """
     Returns the last day of a month
@@ -249,6 +264,19 @@ def check_str_in_list(list_in: List, str2check: str_or_List, labort: bool = True
         return stat, []
 
 
+def shape_from_str(fname):
+    """
+    Retrieves shapes from AtmoRep output-filenames.
+    From AtmoRep-project: https://isggit.cs.uni-magdeburg.de/atmorep/atmorep/
+    :param fname: filename of AtmoRep output-file
+    :return shapes inferred from AtmoRep output-file
+    """
+    shapes_str = fname.replace("_phys.dat", ".dat").split(".")[-2].split("_s_")[-1].split("_") #remove .ext and split
+    shapes = [int(i) for i in shapes_str]
+    return shapes
+
+
 def find_closest_divisor(n1, div):
     """
     Function to find closest divisor for a given number with respect to a target value
@@ -275,16 +303,18 @@ def find_closest_divisor(n1, div):
         return all_divs[i]
 
 
-def free_mem(var_list: List):
-    """
-    Delete all variables in var_list and release memory
-    :param var_list: list of variables to be deleted
-    """
-    var_list = to_list(var_list)
-    for var in var_list:
-        del var
-
-    gc.collect()
+#def free_mem(var_list: List):
+# *** This was found to be ineffective, cf. michael_issue085-memory_footprint ***
+#
+#    """
+#    Delete all variables in var_list and release memory
+#    :param var_list: list of variables to be deleted
+#    """
+#    var_list = to_list(var_list)
+#    for var in var_list:
+#        del var
+#
+#    gc.collect()
 
 # The following auxiliary methods have been adapted from MAELSTROM AP3,
 # see https://git.ecmwf.int/projects/MLFET/repos/maelstrom-radiation/browse/climetlab_maelstrom_radiation/benchmarks/
@@ -365,4 +395,23 @@ def copy_filelist(file_list: List, dest_dir: str, file_list_dest: List = None ,l
             if labort:
                 raise FileNotFoundError(f"Could not find file '{f}'. Error will be raised.")
             else:
-                print(f"WARNING: Could not find file '{f}'.")
\ No newline at end of file
+                print(f"WARNING: Could not find file '{f}'.")
+
+def merge_dicts(default_dict, user_dict):
+    """
+    Recursively merge two dictionaries, ensuring that all default keys are set.
+    """
+    merged_dict = default_dict.copy()
+
+    for key, value in user_dict.items():
+        if isinstance(value, dict) and key in merged_dict and isinstance(merged_dict[key], dict):
+            # If the value is a dictionary and the key exists in both dictionaries,
+            # recursively merge the dictionaries.
+            merged_dict[key] = merge_dicts(merged_dict[key], value)
+        else:
+            # Otherwise, set the value in the merged dictionary.
+            assert isinstance(value, type(merged_dict[key])), \
+                f"Type mismatch for key '{key}': {type(value)} != {type(merged_dict[key])}"
+            merged_dict[key] = value
+
+    return merged_dict
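+
+# Illustrative usage sketch for merge_dicts (names and values are assumptions):
+#   defaults = {"lr": 5e-05, "hparams": {"activation": "swish"}}
+#   user_conf = {"hparams": {"activation": "relu"}}
+#   merge_dicts(defaults, user_conf)  # -> {"lr": 5e-05, "hparams": {"activation": "relu"}}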
diff --git a/downscaling_ap5/utils/read_atmorep_data.py b/downscaling_ap5/utils/read_atmorep_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab8646131030f687bbf33eaa574dfd6eb19e3d1c
--- /dev/null
+++ b/downscaling_ap5/utils/read_atmorep_data.py
@@ -0,0 +1,761 @@
+# SPDX-FileCopyrightText: 2023 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
+#
+# SPDX-License-Identifier: MIT
+
+__author__ = "Michael Langguth"
+__email__ = "m.langguth@fz-juelich.de"
+__date__ = "2023-05-21"
+__update__ = "2023-08-15"
+
+# import modules
+import os, sys
+sys.path.append("./")
+import glob
+import json
+from typing import List, Tuple
+from time import time as timer
+from tqdm import tqdm
+
+import pandas as pd
+import xarray as xr
+import numpy as np
+
+from other_utils import doy_to_mo, shape_from_str
+
+# the class to handle AtmoRep data
+class HandleAtmoRepData(object):
+    """
+    Handle output data of AtmoRep.
+    TO-DO:
+        - get dx-info from token-config
+        - add ibatch-information to coordinates of masked data
+        - sanity check on data (partly done)
+    """
+    known_data_types = ["source", "prediction", "target", "ensembles"]
+
+    # offsets to reconstruct correct timestamps
+    days_offset = 1
+    hours_offset = 0
+    
+    def __init__(self, model_id: str, dx_in: float, atmorep_dir: str = "/p/scratch/atmo-rep/results/", 
+                 in_dir: str = "/p/scratch/atmo-rep/data/era5/ml_levels/",
+                 target_type: str = "fields_prediction", dx_tar: float = None,
+                 tar_dir: str = None, epsilon: float = 0.001):
+        """
+        :param model_id: ID of Atmorep-run to process
+        :param dx_in: grid spacing of input data
+        :param atmorep_dir: Base-directory where AtmoRep-output is located
+        :param in_dir: Base-directory for input data used to train AtmoRep (to get normalization/correction-data)
+        :param target_type: Either fields_prediction or fields_target
+        :param dx_tar: grid spacing of target data
+        :param tar_dir: Base-directory for target data used to train AtmoRep (to get normalization/correction-data)
+        :param epsilon: epsilon parameter for log transformation (if applied at all)
+        """
+        self.model_id = model_id if model_id.startswith("id") else f"id{model_id}"
+        self.datadir = os.path.join(atmorep_dir, self.model_id)
+        self.datadir_input = in_dir
+        self.datadir_target = self.datadir_input if tar_dir is None else tar_dir
+        self.target_type = target_type
+        self.dx_in = dx_in
+        self.dx_tar = self.dx_in if dx_tar is None else dx_tar
+        self.epsilon = epsilon
+        self.config_file, self.config = self._get_config()
+        self.input_variables = self._get_invars()
+        self.target_variables = self._get_tarvars()
+        
+        self.input_token_config = self.get_input_token_config()
+        self.target_token_config = self.get_target_token_config() 
+        
+    def _get_config(self) -> Tuple[str, dict]:
+        """
+        Get configuration dictionary of trained AtmoRep-model.
+        """
+        config_jsf = os.path.join(self.datadir, f"model_{self.model_id}.json")
+        with open(config_jsf) as json_file:
+            config = json.load(json_file)
+        return config_jsf, config
+    
+    def _get_invars(self) -> List:
+        """
+        Get list of input variables of trained AtmoRep-model.
+        """
+        return [var_list[0] for var_list in self.config["fields"]]
+        #return list(np.asarray(self.config["fields"], dtype=object)[:, 0])
+    
+    def _get_tarvars(self) -> List:
+        """
+        Get list of target variables of trained AtmoRep-model.
+        """
+        return [var_list[0] for var_list in self.config[self.target_type]]
+        #return list(np.asarray(self.config[self.target_type])[:, 0])
+    
+    def _get_token_config(self, key) -> dict:
+        """
+        Generic function to retrieve token configuration.
+        :param key: key-string from which the token-info is deducible
+        :return dictionary of token info
+        """
+        token_config_keys = ["general_config", "vlevel", "num_tokens", "token_shape", "bert_parameters"]
+        token_config = {var[0]: None for var in self.config[key]}        
+        for i, var in enumerate(token_config):
+            len_config = len(self.config[key][i])
+            token_config[var] =  {config_key: self.config[key][i][j+1] for j, config_key in enumerate(token_config_keys)}
+            if len_config >= 7:
+                token_config[var]["norm_type"] = self.config[key][i][6]
+                if len_config >= 8:
+                    if isinstance(self.config[key][i][7][-1], bool):    
+                        token_config[var]["log_transform"] = self.config[key][i][7][-1]
+                    else: 
+                        token_config[var]["log_transform"] = False
+            else: # default setting for normalization type
+                token_config[var]["norm_type"] = "global"
+        
+        return token_config
+    
+    def get_input_token_config(self) -> dict:
+        """
+        Get input token configuration
+        """
+        return self._get_token_config("fields")
+        
+    def get_target_token_config(self) -> dict:
+        """
+        Retrieve token configuration of output/target data.
+        Note that the token configuration is the same as for the input data as long as target_fields is unset.
+        """
+        if self.target_type in ["target_fields", "fields_targets"]:
+            return self._get_token_config(self.target_type)   
+        else:
+            return self.input_token_config
+    
+    def _get_token_info(self, token_type: str, rank: int = 0, epoch: int = 0, batch: int = 0, varname: str = None, lmasked: bool = True):  
+        """
+        Retrieve token information. 
+        Note: For BERT-training (i.e. lmasked == True), the tokens are scattered in time and space. 
+              Thus, the token information will be returned as an index-based array, whereas it is otherwise
+              returned in a structured manner (i.e. shaped as (nbatch, nvlevel, nt, ny, nx), where nt, ny, nx
+              represent the number of tokens in the (time, lat, lon)-dimensions).
+        :param token_type: Type of token for which info should be retrieved (either 'source', 'prediction', 'target' or 'ensembles')
+        :param rank: rank (of job) that has created the requested token information file
+        :param epoch: training epoch of requested token information file
+        :param batch: batch of requested token information
+        :param varname: name of variable for which token info is requested
+        :param lmasked: flag if masking was applied on token (for token_type 'target' or 'prediction' only)
+        """
+        if self.dx_tar != self.dx_in:   # in case of different grid spacing in target and source data, target specific token info files are present
+            add_str_tar = "target"
+            # ML: Required due to inconsistency in naming convention for files providing token info
+            #     and masked token indices (see below)
+            add_str_tar2 = "targets" 
+        else:
+            add_str_tar, add_str_tar2 = "", ""
+        var_aux = varname
+        if token_type == self.known_data_types[0]:
+            var_aux = self.input_variables[0] if var_aux is None else var_aux      
+            fpatt = os.path.join(self.datadir, f"*rank{rank}_epoch{epoch:05d}_batch{batch:05d}_token_infos_{var_aux}*.dat")
+            token_dict = self.input_token_config[var_aux]
+        elif token_type in self.known_data_types[1:]:
+            var_aux = self.target_variables[0] if var_aux is None else var_aux 
+            fpatt = os.path.join(self.datadir, f"*rank{rank}_epoch{epoch:05d}_batch{batch:05d}_{add_str_tar}_token_infos_{var_aux}*.dat")
+            token_dict = self.target_token_config[var_aux]
+        else:
+            raise ValueError(f"Parsed token type '{token_type}' is unknown. Choose one of the following: {*self.known_data_types,}")
+        # Get file for token info
+        fname_tokinfos_now = self.get_token_file(fpatt)
+        
+        # read token info and reshape
+        shape_aux = shape_from_str(fname_tokinfos_now)
+        tokinfo_data = np.fromfile(fname_tokinfos_now, dtype=np.float32)
+        # hack for downscaling
+        if token_type in ["prediction", "target"] and "downscaling_num_layers" in self.config:
+            token_dict["num_tokens"] = [1, *token_dict["num_tokens"][1::]]
+        tokinfo_data = tokinfo_data.reshape(shape_aux[0], len(token_dict["vlevel"]), *token_dict["num_tokens"], shape_aux[-1])
+
+        if token_type in self.known_data_types[1:] and lmasked:
+            # read indices of masked tokens
+            fpatt = os.path.join(self.datadir, f"*rank{rank}_epoch{epoch:05d}_batch{batch:05d}_{add_str_tar2}_tokens_masked_idx_{var_aux}*.dat")
+            fname_tok_maskedidx = self.get_token_file(fpatt)
+            
+            tok_masked_idx = np.fromfile(fname_tok_maskedidx, dtype=np.int64)
+            # reshape token info for slicing...
+            tokinfo_data = tokinfo_data.transpose(1, 0, 2, 3, 4, 5).reshape( -1, shape_aux[-1])
+            # ... and slice to relevant data
+            tokinfo_data = tokinfo_data[tok_masked_idx]   
+        
+        return tokinfo_data
+    
+    def _get_date_nomask(self, tokinfo_data, nt):
+        
+        assert tokinfo_data.ndim == 6, f"Parsed token info data has {tokinfo_data.ndim} dimensions, but 6 are expected."
+        
+        times = np.array(np.floor(np.delete(tokinfo_data, np.s_[3:], axis = -1)), dtype = np.int32) # ?
+        times = times[:,0,:,0,0,:]                                     # remove redundant dimensions
+
+        years, days, hours = times[:,:,0].flatten(), times[:,:,1].flatten(), times[:,:,2].flatten()
+        years = np.array([int(y) for y in years])
+
+        # ML: Should not be required in the future!!!
+        if any(days < 0):
+            years = np.where(days < 0, years - 1, years)
+            days_new = np.array([pd.Timestamp(y, 12, 31).dayofyear - self.days_offset for y in years])
+            days = np.where(days < 0, days_new - days, days)
+
+        if any(days > 364):
+            days_per_year = np.array([pd.Timestamp(y, 12, 31).dayofyear for y in years]) - self.days_offset
+            years = np.where(days > days_per_year, years + 1, years)
+            days = np.where(days > days_per_year, days - days_per_year - self.days_offset, days)
+
+        # construct date information for centered data position of tokens
+        dates = doy_to_mo(days + self.days_offset, years)
+        dates = dates + pd.TimedeltaIndex(hours.flatten(), unit='h')
+        ### ML: Is this really required???
+        # apply hourly offset
+        dates = dates - pd.TimedeltaIndex(np.ones(dates.shape)*self.hours_offset, unit="h")
+        # reshape and construct remaining date information of tokens
+        dates = np.array(dates).reshape(times.shape[0:-1])
+        dates = np.array([list(dates + pd.TimedeltaIndex(np.ones(dates.shape)*hh, unit="h")) for hh in range(-int(nt/2), int(nt/2) + 1)])
+        dates = dates.transpose(1, 2, 0).reshape(times.shape[0], -1)
+
+        return dates
+
+    def _get_date_masked(self, tokinfo_data, nt):
+        
+        assert tokinfo_data.ndim == 2, f"Parsed token info data has {tokinfo_data.ndim} dimensions, but 2 are expected."
+        
+        years, days, hours = tokinfo_data[:, 0].flatten(), tokinfo_data[:, 1].flatten(), tokinfo_data[:, 2].flatten()
+        days = np.array(np.floor(days), dtype=np.int32)
+        # ML: Should not be required in the future!!!
+        if any(days < 0):
+            years = np.where(days < 0, years - 1, years)
+            days_new = np.array([pd.Timestamp(y, 12, 31).dayofyear - self.days_offset for y in years])
+            days = np.where(days < 0, days_new - days, days)
+
+        if any(days > 364):
+            days_per_year = np.array([pd.Timestamp(y, 12, 31).dayofyear for y in years]) - self.days_offset
+            years = np.where(days > days_per_year, years + 1, years)
+            days = np.where(days > days_per_year, days - days_per_year - self.days_offset, days)
+                          
+        dates = doy_to_mo(days + self.days_offset, years)      # add 1 to days since counting starts with zero in token_info
+        dates = dates + pd.TimedeltaIndex(hours, unit='h')
+        ### ML: Is this really required???
+        # apply hourly offset
+        dates = dates - pd.TimedeltaIndex(np.ones(dates.shape)*self.hours_offset, unit="h") 
+        
+        # reshape and construct remaining date information of tokens  
+        dates = np.array([dates + pd.TimedeltaIndex(np.ones(dates.shape)*hh, unit="h") for hh in range(-int(nt/2), int(nt/2) + 1)])
+        dates = dates.transpose()
+        
+        return np.array(dates)  
+        
+    def get_date(self, tokinfo_data, token_config):
+        """
+        Retrieve dates from token info data 
+        :param tokinfo_data: token info data which was read beforehand by _get_token_info-method
+        :param token_config: corresponding token configuration
+        """
+        nt = token_config["token_shape"][0]
+        
+        ndims_tokinfo = tokinfo_data.ndim
+        
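+        # 2-D token info corresponds to masked (scattered) tokens, 6-D token info to complete, unmasked patches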
+        if ndims_tokinfo == 2:
+            get_date_func = self._get_date_masked
+        elif ndims_tokinfo == 6:
+            get_date_func = self._get_date_nomask
+        else:
+            raise ValueError(f"Parsed tokinfo_data-array has unexpected number of dimensions ({ndims_tokinfo}).")
+        
+        dates = get_date_func(tokinfo_data, nt)       
+        
+        return dates
+    
+    def get_grid(self, tokinfo_data, token_config, dx):
+        """
+        Retrieve underlying geo/grid information.
+        :param tokinfo_data: token info data which was read beforehand by _get_token_info-method
+        :param token_config: corresponding token configuration
+        :param dx: spacing of underlying grid
+        """
+        ndims_tokinfo = tokinfo_data.ndim
+        
+        if ndims_tokinfo == 2:
+            get_grid_func = self._get_grid_masked
+        elif ndims_tokinfo == 6:
+            get_grid_func = self._get_grid_nomask
+        else:
+            raise ValueError(f"Parsed tokinfo_data-array has unexpected number of dimensions ({ndims_tokinfo}).")
+            
+        lats, lons = get_grid_func(tokinfo_data, token_config, dx)
+        
+        return lats, lons
+    
+    
+    def read_one_file(self, fname: str, token_type: str, varname: str, token_config: dict, token_info, dx: float, lmasked: bool = True, ldenormalize: bool = True,
+                     no_mean_denorm: bool = False):
+        """
+        Read data from a single output file of AtmoRep and convert to xarray DataArray with underlying coordinate information.
+        :param token_type: Type of token for which info should be retrieved (either 'input', 'target', 'ensemble' or 'prediction')
+        :param rank: rank (of job) that has created the requested token information file
+        :param epoch: training epoch of requested token information file
+        :param batch: batch of requested token information
+        :param varname: name of variable for which token info is requested
+        :param lmasked: flag if masking was applied on token (for token_type 'target' or 'prediction' only)
+        :param ldenormalize: flag if denormalize/invert correction should be applied (also includes inversion of log transformation)
+        :param no_mean_denorm: flag if mean should not be added when denormalization is performed
+        """              
+        data = np.fromfile(fname, dtype=np.float32)
+        
+        times = self.get_date(token_info, token_config)
+        lats, lons = self.get_grid(token_info, token_config, dx)
+        
+        if token_type in self.known_data_types[1:] and lmasked:
+            data = self._reshape_masked_data(data, token_config, token_type == "ensembles", self.config["net_tail_num_nets"])
+            vlvl = np.array(token_info[:, 3], dtype=np.int32)
+            data = self.masked_data_to_xarray(data, varname, times, vlvl, lats, lons, token_type == "ensembles")
+        else:
+            data = self._reshape_nomask_data(data, token_config, self.config["batch_size_test"])
+            data = self.nomask_data_to_xarray(data, varname, times, token_config["vlevel"], lats, lons)
+
+        if ldenormalize:
+            t0 = timer()
+
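+            # undo an optional log transformation before reverting the (z-score) normalization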
+            if token_config.get("log_transform", False):
+                data = self.invert_log_transform(data)
+
+            if token_type in self.known_data_types[1:] and lmasked:
+                data = self.denormalize_masked_data(data, token_type, token_config["norm_type"], no_mean_denorm)
+            else:
+                data = self.denormalize_nomask_data(data, token_type, token_config["norm_type"], no_mean_denorm)   
+                
+            data.attrs["denormalization time [s]"] = timer() - t0
+        
+        return data
+    
+    def read_data(self, token_type: str, varname: str, rank: int = -1, epoch: int = -1, batch: int = -1, lmasked: bool = True,
+                  ldenormalize: bool = True, no_mean_denorm: bool = False):
+        """
+        Read data from a single output file of AtmoRep and convert to xarray DataArray with underlying coordinate information.
+        :param token_type: Type of token for which info should be retrieved (either 'input', 'target', 'ensemble' or 'prediction')
+        :param rank: rank (of job) that has created the requested token information file
+        :param epoch: training epoch of requested token information file
+        :param batch: batch of requested token information
+        :param varname: name of variable for which token info is requested
+        :param lmasked: flag if masking was applied on token (for token_type 'target' or 'prediction' only)
+        :param ldenormalize: flag if denormalize/invert correction should be applied (also includes inversion of log transformation)
+        :param no_mean_denorm: flag if mean should not be added when denormalization is performed
+        """      
+        if token_type == "source":
+            token_type_str = "source"
+            token_config = self.input_token_config[varname]
+            dx = self.dx_in
+        elif token_type in self.known_data_types[1:]:
+            token_type_str = "preds" if token_type == "prediction" else token_type
+            token_config = self.target_token_config[varname]
+            dx = self.dx_tar
+        else:
+            raise ValueError(f"Parsed token type '{token_type}' is unknown. Choose one of the following: {*self.known_data_types,}")
+            
+        filelist = self.get_hierarchical_sorted_files(token_type_str, varname, rank, epoch, batch)
+        
+        print(f"Start reading {len(filelist)} files for {token_type} data...")
+        lwarn = True
+        for i, f in enumerate(tqdm(filelist)):
+            rank, epoch, batch = self.get_rank_epoch_batch(f)
+            try:
+                token_info = self._get_token_info(token_type, rank=rank, epoch=epoch, batch=batch, varname=varname, lmasked=lmasked)
+            except FileNotFoundError:
+                if lwarn:   # print warning only once
+                    print(f"No token info for {token_type} data of {varname} found. Proceed with token info for input data.")
+                    lwarn = False
+                token_info = self._get_token_info(self.known_data_types[0], rank=rank, epoch=epoch, batch=batch, varname=None, lmasked=lmasked)
+                
+            da_f = self.read_one_file(f, token_type, varname, token_config, token_info, dx, lmasked, ldenormalize, no_mean_denorm)
+            
+            if i == 0:
+                da_list = []
+                denorm_time = da_f.attrs["denormalization time [s]"] if ldenormalize else 0.
+                dim_concat = da_f.dims[0]
+            else:
+                ilast = da_list[i-1][dim_concat][-1] + 1
+                inew = np.arange(ilast, ilast + len(da_f[dim_concat]))
+                da_f = da_f.assign_coords({dim_concat: inew})
+                if ldenormalize:
+                    denorm_time += da_f.attrs["denormalization time [s]"]
+
+            da_list.append(da_f.copy())
+
+        da = xr.concat(da_list, dim=da_list[0].dims[0])
+        if ldenormalize:
+            da.attrs['denormalization time [s]'] = denorm_time
+            print(f"Denormalization of {len(filelist)} files for {token_type} data took {da.attrs['denormalization time [s]']:.2f}s")
+                
+        return da
+        
+    def get_hierarchical_sorted_files(self, token_type_str: str, varname: str, rank: int = -1, epoch: int = -1, batch: int = -1):
+        rank_str = f"rank*" if rank == -1 else f"rank{rank:d}"
+        epoch_str = f"epoch*" if epoch == -1 else f"epoch{epoch:05d}"
+        batch_str = f"batch*" if batch == -1 else f"batch{batch:05d}"
+        
+        fpatt = f"*{self.model_id}_{rank_str}_{epoch_str}_{batch_str}_{token_type_str}_{varname}*.dat"
+        filelist = glob.glob(os.path.join(self.datadir, fpatt))
+
+        filelist = [f for f in filelist if '000-1' not in f] #remove epoch -1
+
+        if len(filelist) == 0:
+            raise FileNotFoundError(f"Could not file any files mathcing pattern '{fpatt}' under directory '{self.datadir}'.")
+        
+        # hierarchical sorting via chained stable sorts (the last key applied dominates):
+        # the resulting file list is ordered by batch, then epoch, then rank
+        sorted_filelist = sorted(filelist, key=lambda x: self.get_number(x, "_rank"))
+        sorted_filelist = sorted(sorted_filelist, key=lambda x: self.get_number(x, "_epoch"))
+        sorted_filelist = sorted(sorted_filelist, key=lambda x: self.get_number(x, "_batch"))
+      
+        return sorted_filelist
+    
+    def denormalize_global(self, da, param_dir, no_mean = False):
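+        """
+        Denormalize data that was normalized with global (monthly) z-score parameters (cf. get_global_norm_params).
+        :param da: normalized (xarray) data array with 'time' and 'vlevel' coordinates
+        :param param_dir: base directory under which the correction/parameter files are located
+        :param no_mean: flag if the mean should not be added back during denormalization
+        """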
+        
+        da_dims = list(da.dims)
+        varname = da.name
+        vlvl = da["vlevel"].values #list(da["vlevel"].values)[0]
+        # get nomralization parameters
+        mean, std = self.get_global_norm_params(varname, vlvl, param_dir)
+        if no_mean:
+            mean[...] = 0.
+
+        # re-index data along time dimension for efficient denormalization
+        time_dims = list(da["time"].dims)
+        da = da.stack({"time_aux": time_dims})
+        time_save = da["time_aux"]                # save information for later re-indexing
+        da = da.set_index({"time_aux": "time"}).sortby("time_aux")
+        time_save = time_save.sortby("time")
+
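+        # group the data by calendar month so each group can be rescaled with its monthly normalization parameters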
+        da = da.resample({"time_aux": "1M"})
+        # loop over year-month items
+        for i, (ts, da_mm) in enumerate(da):
+            yr_mm = pd.to_datetime(ts).strftime("%Y-%m")
+            da_mm = da_mm * std.sel({"year_month": yr_mm}).values + mean.sel({"year_month": yr_mm}).values
+            if i == 0:
+                da_concat = da_mm.copy()
+            else:
+                da_concat = xr.concat([da_concat, da_mm], dim="time_aux")
+
+        da_concat["time_aux"] = time_save
+        if not xr.__version__ == "0.20.1":
+            da_concat = da_concat.reset_index("time_aux")
+        da_concat = da_concat.unstack("time_aux")
+
+        return da_concat.transpose(*da_dims)
+
+    def denormalize_masked_data(self, data: xr.DataArray, token_type: str, norm_type: str, no_mean: bool = False):
+        """
+        Denormalizes/Inverts correction for masked data.
+        The data is assumed to be normalized per vertical level and time, both of which vary along the token dimension.
+        :param data: normalized (xarray) data array providing masked data (cf. masked_data_to_xarray-method) 
+        :param token_type: type of token to be handled, e.g. 'source' (cf. known_data_types)
+        :param norm_type: type of normalization applied to the data (either 'local' or 'global')
+        :param no_mean: flag if data normalization has NOT been zero-meaned 
+        """
+        mm_yr = np.unique(data["time"].dt.strftime("y%Y_m%m"))
+        vlevels = np.unique(data["vlevel"])
+        varname = data.name
+        dim0 = data.dims[0]
+        basedir = self.datadir_input if token_type == self.known_data_types[0] else self.datadir_target
+
+        for vlvl in vlevels:
+            if norm_type == "local":
+                datadir = os.path.join(basedir, f"{vlvl}", "corrections", varname)
+                for month in mm_yr:
+                    fcorr_now = os.path.join(datadir, f"corrections_mean_var_{varname}_{month}_ml{vlvl}.nc")
+                    norm_data = xr.open_dataset(fcorr_now)
+                    mean, std = norm_data[f"{varname}_ml{vlvl:d}_mean"], norm_data[f"{varname}_ml{vlvl:d}_std"]
+                    if no_mean: mean[...] = 0.
+
+                    mask = (data["vlevel"] == vlvl) & (data["time"].dt.strftime("y%Y_m%m") == month)
+                    data_now = data.where(mask, drop=True)
+                    for it in data_now[dim0]:   # loop over all tokens
+                        xy_dict = {"lat": data_now.sel({dim0: it})["lat"], "lon": data_now.sel({dim0: it})["lon"]}
+                        mu_it, std_it = mean.sel(xy_dict), std.sel(xy_dict)
+                        data_now.loc[{dim0: it}] = data_now.loc[{dim0: it}]*std_it + mu_it
+                    data = xr.where(mask, data_now, data)
+                        
+            elif norm_type == "global":
+                mask = data["vlevel"] == vlvl
+                data = xr.where(mask, self.denormalize_global(data.where(mask), basedir, no_mean = no_mean), data)
+                    
+        return data
+    
+    def denormalize_nomask_data(self, data: xr.DataArray, token_type:str, norm_type:str, no_mean: bool = False):
+        """
+        Denormalizes/Inverts correction for unmasked data.
+        The data is assumed to be normalized per vertical level and time.
+        :param data: normalized (xarray) data array providing unmasked data (cf. nomask_data_to_xarray-method) 
+        :param token_type: type of token to be handled, e.g. 'source' (cf. known_data_types)
+        :param norm_type: type of normalization applied to the data (either 'local' or 'global')
+        :param no_mean: flag if data normalization has NOT been zero-meaned 
+        """ 
+        times = data["time"]
+        nt = len(times[0,:])
+        dim0 = data.dims[0]
+        varname = data.name
+        basedir = self.datadir_input if token_type == self.known_data_types[0] else self.datadir_target
+        
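+        # the temporal centre of each sample determines the year-month used for the parameter lookup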
+        center_times = times.min(dim="t") + pd.Timedelta(nt/2, "hours")
+        yr_mm = np.unique(center_times.dt.strftime("y%Y_m%m"))
+
+        for vlvl in data["vlevel"]:
+            if norm_type == "local":
+                iv = vlvl.values 
+                data_dir = os.path.join(basedir, f"{iv:d}", "corrections", varname)
+                for month in yr_mm:
+                    fcorr_now = os.path.join(data_dir, f"corrections_mean_var_{varname}_{month}_ml{iv:d}.nc")
+                    norm_data = xr.open_dataset(fcorr_now)
+                    mean, std = norm_data[f"{varname}_ml{iv:d}_mean"], norm_data[f"{varname}_ml{iv:d}_std"]
+                    if no_mean: mean[...] = 0.
+
+                    mask = (data["time"].dt.strftime("y%Y_m%m") == month)
+                    data_now = data.where(mask, drop=True)
+                    for it in data_now[dim0]:
+                        xy_dict = {"lat": data_now.sel({dim0: it})["lat"], "lon": data_now.sel({dim0: it})["lon"]}
+                        mu_it, std_it = mean.sel(xy_dict), std.sel(xy_dict)
+                        data.loc[{dim0: it, "vlevel": iv}] = data.loc[{dim0: it, "vlevel": iv}]*std_it + mu_it 
+
+            elif norm_type == "global": 
+                data.loc[{"vlevel": vlvl}] = self.denormalize_global(data.sel({"vlevel": vlvl}), basedir, no_mean = no_mean)  
+                    
+        return data
+
+    def invert_log_transform(self, data):
+        """
+        Inverts log transformation on data.
+        :param data: the xarray DataArray which was log transformed
+        :return: data after inversion of log transformation
+        """
+        data = self.epsilon*(np.exp(data) - 1.)
+        
+        return data
+    
+    def get_rank_epoch_batch(self, fname, to_int: bool=True):
+        rank = self.get_number(fname, "_rank")
+        epoch = self.get_number(fname, "_epoch")
+        batch = self.get_number(fname, "_batch")
+        
+        if to_int:
+            rank, epoch, batch = int(rank), int(epoch), int(batch)
+        
+        return rank, epoch, batch
+    
+    @staticmethod
+    def nomask_data_to_xarray(data_np, varname: str, times, vlvls, lat, lon, lensemble: bool = False):
+        
+        if lensemble:
+            raise ValueError("Ensemlbe not supported yet.")
+        
+        nbatch, nt = data_np.shape[0], data_np.shape[2]
+        nlat, nlon = len(lat[0, :]), len(lon[0, :])
+        
+        da = xr.DataArray(data_np, dims=["ibatch", "vlevel", "t", "y", "x"],
+                          coords={"ibatch": np.arange(nbatch), "vlevel": vlvls,
+                                  "t": np.arange(nt), "y": np.arange(nlat), "x": np.arange(nlon),
+                                  "time": (["ibatch", "t"], times), 
+                                  "lat": (["ibatch", "y"], lat), "lon": (["ibatch", "x"], lon)},
+                         name=varname)
+        return da
+    
+    @staticmethod
+    def masked_data_to_xarray(data_np, varname: str, times, vlvls, lat, lon, lensemble: bool = False):
+            
+        data_dims = ["itoken", "tt", "yt", "xt"]        
+        if lensemble:
+            ntoken, nens, ntt, yt, xt = data_np.shape
+            data_dims.insert(1, "ens")
+        else:
+            ntoken, ntt, yt, xt = data_np.shape
+
+        data_coords = {"itoken": np.arange(ntoken), "tt": np.arange(ntt), "yt": np.arange(yt), 
+                       "xt": np.arange(xt), "time": (["itoken", "tt"], times),
+                       "vlevel": ("itoken", vlvls),
+                       "lat": (["itoken", "yt"], lat), "lon": (["itoken", "xt"], lon)}        
+        
+        if lensemble:
+            data_coords["ens"] = np.arange(nens)
+        
+        da = xr.DataArray(data_np, dims=data_dims, coords=data_coords, name = varname)
+
+        return da
+    
+    @staticmethod
+    def get_number(file_name, split_arg):
+        # Extract the number from the file name using the provided split argument
+        return int(file_name.split(split_arg)[1].split('_')[0])
+    
+    @staticmethod
+    def get_token_file(fpatt: str, nfiles: int = 1):
+        # Get file for token info
+        fnames = glob.glob(fpatt)
+        # sanity check
+        if not fnames:
+            raise FileNotFoundError(f"Could not find required file(-s) using the following filename pattern: '{fpatt}'")
+            
+        assert len(fnames) == nfiles, f"Found {len(fnames)} file(s) matching pattern '{fpatt}', but {nfiles} expected."
+        
+        if nfiles == 1: 
+            fnames = fnames[0]
+        
+        return fnames    
+    
+    @staticmethod
+    def _get_grid_nomask(tokinfo_data, token_config, dx):
+        """
+        Retrieve underlying geo/grid information for unmasked data (complete patches!)
+        :param tokinfo_data: token info data which was read beforehand by _get_token_info-method
+        :param token_config: corresponding token configuration
+        :param dx: spacing of underlying grid
+        """
+        # retrieve spatial token size
+        ny, nx = token_config["token_shape"][1:3]
+
+        # off-centering for even number of grid points
+        ny1, nx1 = 1, 1
+        
+        if ny%2 == 0: 
+            ny1 = 0
+        if nx%2 == 0:
+            nx1 = 0
+        
+        lat_offset = np.arange(-int((ny-ny1)/2)*dx, int(ny/2+ny1)*dx, dx)
+        lon_offset = np.arange(-int((nx-nx1)/2)*dx, int(nx/2+nx1)*dx, dx)   
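+        # the offsets span the ny x nx grid points of each token around the centre coordinates stored in the token info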
+        
+        #IMPORTANT: need to swap axes in lats to have ntokens_lat adjacent to tokinfo --> 8x4x8x8 -> 8x8x4x8
+        lons  = np.array([tokinfo_data.swapaxes(0, -1)[-3]+lon_offset[i] for i in range(len(lon_offset))]).swapaxes(0, -1)%360
+        lats  = np.array([tokinfo_data.swapaxes(-3, -2).swapaxes(0, -1)[-4]+lat_offset[i] for i in range(len(lat_offset))]).swapaxes(0, -1)%180
+        
+        lats, lons = lats[:, 0, 0, 0, :, :], lons[:, 0, 0, 0, :, :]
+        lats, lons = lats.reshape(lats.shape[0], -1), lons.reshape(lats.shape[0], -1)
+        
+        # correct lat values because they are in 'mathematical coordinates' with 0 at the North Pole
+        lats = 90. - lats
+        
+        return lats, lons  
+    
+    @staticmethod
+    def _get_grid_masked(tokinfo_data, token_config, dx):  
+        """
+        Retrieve underlying geo/grid information for masked data (scattered tokens!)
+        :param tokinfo_data: token info data which was read beforehand by _get_token_info-method
+        :param token_config: corresponding token configuration
+        :param dx: spacing of underlying grid
+        """
+        # retrieve spatial token size
+        ntok = tokinfo_data.shape[0]
+        ny, nx = token_config["token_shape"][1:3]
+        
+        # off-centering for even number of grid points
+        ny1, nx1 = 1, 1
+        if ny%2 == 0: 
+            ny1 = 0
+        if nx%2 == 0:
+            nx1 = 0
+        
+        lat_offset = np.arange(-int((ny-ny1)/2+nx1)*dx, int(ny/2+ny1)*dx, dx)
+        lon_offset = np.arange(-int((nx-nx1)/2+ny1)*dx, int(nx/2+nx1)*dx, dx) 
+        
+        lats = np.array([tokinfo_data[idx, 4] + lat_offset for idx in range(ntok)]) % 180 #boundary conditions
+        lons = np.array([tokinfo_data[idx, 5] + lon_offset for idx in range(ntok)]) % 360 #boundary conditions
+        
+        # correct lat values because they are in 'mathematical coordinates' with 0 at the North Pole
+        lats = 90. - lats
+        
+        return lats, lons
+    
+
+    @staticmethod
+    def get_global_norm_params(varname, vlv, basedir):
+        """
+        Read parameter files for global z-score normalization
+        :param varname: name of variable
+        :param vlv: vertical level index
+        :param basedir: base directory under which correction/parameter files are located
+        """
+        fcorr = os.path.join(basedir, f"{vlv}", "corrections", f"global_corrections_mean_var_{varname}_ml{vlv:d}.bin")
+        corr_data = np.fromfile(fcorr, dtype="float32").reshape(-1, 4)
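+        # each float32 record holds four values, interpreted below as (year, month, mean, var)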
+
+        years, months = corr_data[:,0], corr_data[:,1]
+
+        # hack: filter data where years equal to zero
+        bad_inds = np.nonzero(years == 0)
+        years, months = np.delete(years, bad_inds), np.delete(months, bad_inds)
+        mean, var = np.delete(corr_data[:,2], bad_inds), np.delete(corr_data[:,3], bad_inds)
+
+        yr_mm = [pd.to_datetime(f"{int(yr):d}-{int(m):02d}", format="%Y-%m") for yr, m in zip(years, months)]
+
+        mean = xr.DataArray(mean, dims=["year_month"], coords={"year_month": yr_mm})
+        var = xr.DataArray(var, dims=["year_month"], coords={"year_month": yr_mm})
+
+        return mean, var    
+    
+    @staticmethod
+    def _reshape_nomask_data(data, token_config: dict, batch_size: int):
+        """
+        Reshape unmasked token data that has been read from .dat-files of AtmoRep (complete patches!).
+        Data is assumed to cover complete, unmasked patches (all tokens of each batch sample).
+        Adapted from Ilaria, but with in-place operations to save memory.
+        """
+        sh = (batch_size, len(token_config["vlevel"]), *token_config["num_tokens"], *token_config["token_shape"])
+        data = data.reshape(*sh)
+        
+        # further reshaping to collapse the token dimensions
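+        # assuming num_tokens and token_shape are ordered (t, y, x), the transpose yields
+        # (batch, vlevel, ntok_t, t, ntok_y, y, ntok_x, x); the subsequent reshapes then merge
+        # (ntok_x, x), (ntok_y, y) and (ntok_t, t) into full x-, y- and t-axes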
+        data = np.transpose(data, (0,1,2,5,3,6,4,7))
+        data = data.reshape(*data.shape[:-2], -1)
+        data = data.reshape(*data.shape[:-3], -1, *data.shape[-1:])
+        data = data.reshape(*data.shape[:2], -1, *data.shape[4:])
+
+        return data
+    
+    @staticmethod
+    def _reshape_masked_data(data, token_config, lensemble: bool = False, nens: int = None):
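+        """
+        Reshape masked token data read from .dat-files of AtmoRep into (token[, ens], t, y, x).
+        """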
+        sh0 = (-1,)
+        if lensemble:
+            assert nens > 0, f"Invalid nens value passed ({nens}). It must be an integer > 0"
+            sh0 = (-1, nens)
+
+        token_sh = token_config["token_shape"]
+        data = data.reshape(*sh0, *token_sh)
+        
+        return data