diff --git a/downscaling_ap5/HPC_batch_scripts/train_unet_model_template_e4.sh b/downscaling_ap5/HPC_batch_scripts/train_unet_template_e4.sh
similarity index 88%
rename from downscaling_ap5/HPC_batch_scripts/train_unet_model_template_e4.sh
rename to downscaling_ap5/HPC_batch_scripts/train_unet_template_e4.sh
index 3b0718cfde2cef65982e29e343878a9d34d75ab6..5b5539788018e32e2380a14a36cc0d3a2321c690 100644
--- a/downscaling_ap5/HPC_batch_scripts/train_unet_model_template_e4.sh
+++ b/downscaling_ap5/HPC_batch_scripts/train_unet_template_e4.sh
@@ -9,8 +9,8 @@
 #SBATCH --gres=gpu:1
 ##SBATCH --mem=40G
 #SBATCH --time=01:00:00
-#SBATCH --output=train_wgan-model-out.%j
-#SBATCH --error=train_wgan-model-err.%j
+#SBATCH --output=train_wgan-out.%j
+#SBATCH --error=train_wgan-err.%j
 
 ######### Template identifier (don't remove) #########
 echo "Do not run the template scripts"
@@ -47,9 +47,9 @@ export PYTHONPATH=${BASE_DIR}/models:$PYTHONPATH
 export PYTHONPATH=${BASE_DIR}/postprocess:$PYTHONPATH
 echo ${PYTHONPATH}
 
-# data-directories 
-# Note template uses Tier2-dataset. Adapt accordingly for other datasets.
-indir=/data/maelstrom/langguth1/tier2/train
+# data-directories
+# Adapt to your dataset as needed
+indir=<my_input_dir>
 outdir=${BASE_DIR}/trained_models/
 js_model_conf=${BASE_DIR}/config/config_unet.json
 js_ds_conf=${BASE_DIR}/config/config_ds_tier2.json
diff --git a/downscaling_ap5/HPC_batch_scripts/train_unet_model_template_jsc.sh b/downscaling_ap5/HPC_batch_scripts/train_unet_template_jsc.sh
similarity index 83%
rename from downscaling_ap5/HPC_batch_scripts/train_unet_model_template_jsc.sh
rename to downscaling_ap5/HPC_batch_scripts/train_unet_template_jsc.sh
index 7142b664561a493294719e2b8c6f4809665420a4..f14969a7d76fd576961d8ed070870e4ac8a6d63c 100644
--- a/downscaling_ap5/HPC_batch_scripts/train_unet_model_template_jsc.sh
+++ b/downscaling_ap5/HPC_batch_scripts/train_unet_template_jsc.sh
@@ -4,9 +4,10 @@
 #SBATCH --ntasks=1
 ##SBATCH --ntasks-per-node=1
 #SBATCH --cpus-per-task=48
-#SBATCH --output=train_unet-model-out.%j
-#SBATCH --error=train_unet-model-err.%j
+#SBATCH --output=train_unet-out.%j
+#SBATCH --error=train_unet-err.%j
 #SBATCH --time=02:00:00
+##SBATCH --time=20:00:00
 #SBATCH --gres=gpu:1
 ##SBATCH --partition=batch
 ##SBATCH --partition=gpus
@@ -21,8 +22,9 @@ echo "Do not run the template scripts"
 exit 99
 ######### Template identifier (don't remove) #########
 
-# pin CPUs (needed for Slurm>22.05)
-export SRUN_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK}
+# environment variables to support cpus_per_task with Slurm>22.05
+export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}
+export SRUN_CPUS_PER_TASK="${SLURM_CPUS_PER_TASK}"
 
 # basic directories
 WORK_DIR=$(pwd)
@@ -46,7 +48,8 @@ if [ -z ${VIRTUAL_ENV} ]; then
 fi
 
 
-# data-directories 
+# data-directories
+# Adapt to your dataset as needed
 indir=<training_data_dir>
 outdir=${BASE_DIR}/trained_models/
 js_model_conf=${BASE_DIR}/config/config_unet.json
diff --git a/downscaling_ap5/HPC_batch_scripts/train_wgan_model_template_e4.sh b/downscaling_ap5/HPC_batch_scripts/train_wgan_template_e4.sh
similarity index 86%
rename from downscaling_ap5/HPC_batch_scripts/train_wgan_model_template_e4.sh
rename to downscaling_ap5/HPC_batch_scripts/train_wgan_template_e4.sh
index 52669e97bee835e7d2c6eba654926779945cd02d..d370c29f896bef32948c1d9e02c02950dcc6c492 100644
--- a/downscaling_ap5/HPC_batch_scripts/train_wgan_model_template_e4.sh
+++ b/downscaling_ap5/HPC_batch_scripts/train_wgan_template_e4.sh
@@ -9,8 +9,8 @@
 #SBATCH --gres=gpu:1
 ##SBATCH --mem=40G
 #SBATCH --time=01:00:00
-#SBATCH --output=train_wgan-model-out.%j
-#SBATCH --error=train_wgan-model-err.%j
+#SBATCH --output=train_wgan-out.%j
+#SBATCH --error=train_wgan-err.%j
 
 ######### Template identifier (don't remove) #########
 echo "Do not run the template scripts"
@@ -47,10 +47,10 @@ export PYTHONPATH=${BASE_DIR}/models:$PYTHONPATH
 export PYTHONPATH=${BASE_DIR}/postprocess:$PYTHONPATH
 echo ${PYTHONPATH}
 
-# data-directories 
-# Note template uses Tier2-dataset. Adapt accordingly for other datasets.
-indir=/data/maelstrom/langguth1/tier2/train
-outdir=${BASE_DIR}/tranied_models/
+# data-directories
+# Adapt to your dataset as needed
+indir=<my_input_dir>
+outdir=${BASE_DIR}/trained_models/
 js_model_conf=${BASE_DIR}/config/config_wgan.json
 js_ds_conf=${BASE_DIR}/config/config_ds_tier2.json
 
diff --git a/downscaling_ap5/HPC_batch_scripts/train_wgan_model_template_jsc.sh b/downscaling_ap5/HPC_batch_scripts/train_wgan_template_jsc.sh
similarity index 83%
rename from downscaling_ap5/HPC_batch_scripts/train_wgan_model_template_jsc.sh
rename to downscaling_ap5/HPC_batch_scripts/train_wgan_template_jsc.sh
index ac3986d62f941799865cdc65d4720cf3e0e11c21..cef8bee4a413dc8df53e9ec04347add8a0ffbd7f 100644
--- a/downscaling_ap5/HPC_batch_scripts/train_wgan_model_template_jsc.sh
+++ b/downscaling_ap5/HPC_batch_scripts/train_wgan_template_jsc.sh
@@ -4,9 +4,10 @@
 #SBATCH --ntasks=1
 ##SBATCH --ntasks-per-node=1
 #SBATCH --cpus-per-task=48
-#SBATCH --output=train_wgan-model-out.%j
-#SBATCH --error=train_wgan-model-err.%j
+#SBATCH --output=train_wgan-out.%j
+#SBATCH --error=train_wgan-err.%j
 #SBATCH --time=02:00:00
+##SBATCH --time=20:00:00
 #SBATCH --gres=gpu:1
 ##SBATCH --partition=batch
 ##SBATCH --partition=gpus
@@ -21,8 +22,9 @@ echo "Do not run the template scripts"
 exit 99
 ######### Template identifier (don't remove) #########
 
-# pin CPUs (needed for Slurm>22.05)
-export SRUN_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK}
+# environment variables to support cpus_per_task with Slurm>22.05
+export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}
+export SRUN_CPUS_PER_TASK="${SLURM_CPUS_PER_TASK}"
 
 # basic directories
 WORK_DIR=$(pwd)
@@ -46,7 +48,8 @@ if [ -z ${VIRTUAL_ENV} ]; then
 fi
 
 
-# data-directories 
+# data-directories
+# Adapt to your dataset as needed
 indir=<training_data_dir>
 outdir=${BASE_DIR}/trained_models/
 js_model_conf=${BASE_DIR}/config/config_wgan.json
diff --git a/downscaling_ap5/config/config_ds_tier1.json b/downscaling_ap5/config/config_ds_tier1.json
index 906ce5e2e4e75cb7709b1f8d245322a3fdce2b75..495c15e7b27071d8f347b02f01f88b309313db13 100644
--- a/downscaling_ap5/config/config_ds_tier1.json
+++ b/downscaling_ap5/config/config_ds_tier1.json
@@ -2,5 +2,6 @@
   "norm_dims": ["time", "lat", "lon"],
   "batch_size": 32,
   "var_tar2in": "z_tar",
-  "laugmented": true
+  "laugmented": true,
+  "predictands": ["t2m_tar", "z_tar"]
 }
diff --git a/downscaling_ap5/config/config_ds_tier2.json b/downscaling_ap5/config/config_ds_tier2.json
index 96c38a9d362b7a7490eb69dc877054d3ebf22986..e47fc5f6efb8d97d5a9b72ca54d930b92312373c 100644
--- a/downscaling_ap5/config/config_ds_tier2.json
+++ b/downscaling_ap5/config/config_ds_tier2.json
@@ -1,5 +1,6 @@
 {
   "norm_dims": ["time", "rlat", "rlon"],
   "batch_size": 32,
-  "var_tar2in": "hsurf_tar"
+  "var_tar2in": "hsurf_tar",
+  "predictands": ["t_2m_tar", "hsurf_tar"]
 }
diff --git a/downscaling_ap5/config/config_ds_tier2_wind.json b/downscaling_ap5/config/config_ds_tier2_wind.json
new file mode 100644
index 0000000000000000000000000000000000000000..9591a95ef1502d85659e42f285a8ff5b2bce4ef7
--- /dev/null
+++ b/downscaling_ap5/config/config_ds_tier2_wind.json
@@ -0,0 +1,6 @@
+{
+  "norm_dims": ["time", "rlat", "rlon"],
+  "batch_size": 32,
+  "var_tar2in": "hsurf_tar",
+  "predictands": ["u_10m_tar", "v_10m_tar", "hsurf_tar"]
+}
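The dataset configurations above now declare their target variables explicitly via the new "predictands" entry. A minimal sketch (illustrative only, not part of the patch) of how such a file is consumed, mirroring the pattern in main_scripts/main_train.py; the path is a placeholder:

    import json

    # load a dataset configuration and inspect the new key (placeholder path)
    with open("downscaling_ap5/config/config_ds_tier2_wind.json") as f:
        ds_dict = json.load(f)

    print(ds_dict["predictands"])   # ['u_10m_tar', 'v_10m_tar', 'hsurf_tar']

    # the list is then forwarded to the data-stream methods, e.g.
    # HandleDataClass.make_tf_dataset_allmem(da_train, bs_train, ds_dict["predictands"], ...)
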
diff --git a/downscaling_ap5/handle_data/handle_data_class.py b/downscaling_ap5/handle_data/handle_data_class.py
index 1abf0507e0ff562700bbe3d905451f5e5aa86861..4849de9bf506ca5e3b46d9e892d1ef38b56d765e 100644
--- a/downscaling_ap5/handle_data/handle_data_class.py
+++ b/downscaling_ap5/handle_data/handle_data_class.py
@@ -5,7 +5,7 @@
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-01-20"
-__update__ = "2023-02-07"
+__update__ = "2023-04-17"
 
 import os, glob
 from typing import List
@@ -128,34 +128,55 @@ class HandleDataClass(object):
         return da
 
     @staticmethod
-    def split_in_tar(da: xr.DataArray, target_var: str = "t2m") -> (xr.DataArray, xr.DataArray):
+    def split_in_tar(da: xr.DataArray, predictands: List = None, predictors: List = None) -> (xr.DataArray, xr.DataArray):
         """
         Split data array with variables-dimension into input and target data for downscaling
         :param da: The unsplitted data array
-        :param target_var: Name of target variable which should consttute the first channel
+        :param predictands: List of selected predictand variables; pass None to use
+                            all predictands (vars with suffix _tar)
+        :param predictors: List of selected predictor variables; pass None to use all predictors (vars with suffix _in)
         :return: The split data array.
         """
-        invars = [var for var in da["variables"].values if var.endswith("_in")]
-        tarvars = [var for var in da["variables"].values if var.endswith("_tar")]
+        da_vars = list(da["variables"].values)
 
-        # ensure that ds_tar has a channel coordinate even in case of single target variable
-        roll = False
-        if len(tarvars) == 1:
-            sl_tarvars = tarvars
+        if predictors is None:
+            invars = [var for var in da_vars if var.endswith("_in")]
         else:
-            sl_tarvars = slice(*tarvars)
-            if tarvars[0] != target_var:     # ensure that target variable appears as first channel
-                roll = True
-
-        da_in, da_tar = da.sel({"variables": invars}), da.sel(variables=sl_tarvars)
-        if roll: da_tar = da_tar.roll(variables=1, roll_coords=True)
+            assert all([predictor in da_vars for predictor in
+                        predictors]), f"At least one predictor is not a data variable. Available variables are {*da_vars,}"
+            invars = list(predictors)
+        if predictands is None:
+            tarvars = [var for var in da_vars if var.endswith("_tar")]
+        else:
+            assert all([predictand in da_vars for predictand in
+                        predictands]), f"At least one predictand is not a data variable. Available variables are {*da_vars,}"
+            tarvars = list(predictands)
+
+        # DEPRECATED CODE #
+        # Unnecessary as the predictands-list must be passed to all data stream-methods
+        # (make_tf_dataset_dyn and make_tf_dataset_allmem)
+        # # ensure that ds_tar has a channel coordinate even in case of single target variable
+        # roll = False
+        # if len(tarvars) == 1:
+        #     sl_tarvars = tarvars
+        # else:
+        #     # sl_tarvars = slice(*tarvars)
+        #     sl_tarvars = tarvars
+        #     if tarvars[0] != target_var and predictands is None:  # ensure that target variable appears as first channel
+        #         roll = True
+        #
+        # da_in, da_tar = da.sel({"variables": invars}), da.sel(variables=list(sl_tarvars))
+        # if roll: da_tar = da_tar.roll(variables=1, roll_coords=True)
+        # DEPRECATED CODE #
+        da_in, da_tar = da.sel({"variables": invars}), da.sel({"variables": tarvars})
 
         return da_in, da_tar
 
     @staticmethod
     def make_tf_dataset_dyn(datadir: str, file_patt: str, batch_size: int, nepochs: int, nfiles2merge: int,
-                            lshuffle: bool = True, named_targets: bool = False, predictands: List = None,
-                            predictors: List = None, var_tar2in: str = None, norm_obj=None, norm_dims: List = None,
+                            predictands: List, predictors: List = None, lshuffle: bool = True,
+                            named_targets: bool = False, var_tar2in: str = None, norm_obj=None, norm_dims: List = None,
                             nworkers: int = 10):
         """
         Build TensorFlow dataset by streaming from netCDF using xarray's open_mfdatset-method.
@@ -166,10 +187,10 @@ class HandleDataClass(object):
         :param batch_size: desired mini-batch size
         :param nepochs: (effective) number of epochs for training
         :param nfiles2merge: number if files to merge for streaming
+        :param predictands: List of selected predictand variables
+        :param predictors: List of selected predictor variables; pass None to use all predictors (vars with suffix _in)
         :param lshuffle: boolean to enable sample shuffling
         :param named_targets: boolean if targets will be provided as dictionary with named variables for data stream
-        :param predictors: List of selected predictor variables; parse None to use all data
-        :param predictands: List of selected predictor variables; parse None to use all data
         :param var_tar2in: name of target variable to be added to input (used e.g. for adding high-resolved topography
                                                                          to the input)
         :param norm_dims: names of dimension over which normalization is applied. Should be None if norm_obj is parsed
@@ -186,9 +207,9 @@ class HandleDataClass(object):
 
         if norm_obj: assert isinstance(norm_obj, Normalize), "norm_obj is not an instance of the Normalize-class."
 
-        ds_obj = StreamMonthlyNetCDF(datadir, file_patt, nfiles_merge=nfiles2merge, selected_predictors=predictors,
-                                     selected_predictands=predictands, var_tar2in=var_tar2in,
-                                     norm_obj=norm_obj, norm_dims=norm_dims, nworkers=nworkers)
+        ds_obj = StreamMonthlyNetCDF(datadir, file_patt, nfiles_merge=nfiles2merge, selected_predictands=predictands,
+                                     selected_predictors=predictors, var_tar2in=var_tar2in, norm_obj=norm_obj,
+                                     norm_dims=norm_dims, nworkers=nworkers)
 
         tf_read_nc = lambda ind_set: tf.py_function(ds_obj.read_netcdf, [ind_set], tf.int64)
         tf_choose_data = lambda il: tf.py_function(ds_obj.choose_data, [il], tf.bool)
@@ -218,30 +239,35 @@ class HandleDataClass(object):
         return ds_obj, tfds
 
     @staticmethod
-    def make_tf_dataset_allmem(da: xr.DataArray, batch_size: int, lshuffle: bool = True, shuffle_samples: int = 20000,
-            named_targets: bool = False, var_tar2in: str = None, lrepeat: bool = True, drop_remainder: bool = True, 
-            lembed: bool = False) -> tf.data.Dataset:
+    def make_tf_dataset_allmem(da: xr.DataArray, batch_size: int, predictands: List, predictors: List = None,
+                               lshuffle: bool = True, shuffle_samples: int = 20000, named_targets: bool = False,
+                               var_tar2in: str = None, lrepeat: bool = True, drop_remainder: bool = True,
+                               lembed: bool = False) -> tf.data.Dataset:
         """
         Build-up TensorFlow dataset from a generator based on the xarray-data array.
         NOTE: All data is loaded into memory
         :param da: the data-array from which the dataset should be cretaed. Must have dimensions [time, ..., variables].
                    Input variable names must carry the suffix '_in', whereas it must be '_tar' for target variables
         :param batch_size: number of samples per mini-batch
+        :param predictands: List of selected predictand variables
+        :param predictors: List of selected predictor variables; pass None to use all predictors (vars with suffix _in)
         :param lshuffle: flag if shuffling should be applied to dataset
         :param shuffle_samples: number of samples to load before applying shuffling
         :param named_targets: flag if target of TF dataset should be dictionary with named target variables
         :param var_tar2in: name of target variable to be added to input (used e.g. for adding high-resolved topography
                                                                          to the input)
         :param lrepeat: flag if dataset should be repeated
-        :param drop_remaineder: flag if samples will be dropped in case batch size is not a divisor of # data samples 
+        :param drop_remainder: flag if samples will be dropped in case batch size is not a divisor of # data samples
         :param lembed: flag to trigger temporal embedding (not implemented yet!)
         """
         da = da.load()
-        da_in, da_tar = HandleDataClass.split_in_tar(da)
+        da_in, da_tar = HandleDataClass.split_in_tar(da, predictands=predictands, predictors=predictors)
         if var_tar2in is not None:
-            # NOTE: The following operation order must be the same as in the  read_netcdf-method of
-            #       the StreamMonthlyNetCDF-class!
-            da_in = xr.concat([da_in, da_tar.sel({"variables": var_tar2in})], "variables")
+            # NOTE: * The concatenation order must match StreamMonthlyNetCDF.getitems
+            #       * var_tar2in must be concatenated in front of da_in so that it appears
+            #         as the first channel. This avoids var_tar2in ending up as a
+            #         predictand when the slicing in tf_split takes place
+            da_in = xr.concat([da_tar.sel({"variables": var_tar2in}), da_in], "variables")
 
         varnames_tar = da_tar["variables"].values
 
@@ -372,19 +398,20 @@ def get_dataset_filename(datadir: str, dataset_name: str, subset: str, laugmente
 
 
 class StreamMonthlyNetCDF(object):
-    def __init__(self, datadir, patt, nfiles_merge: int, sample_dim: str = "time", selected_predictors: List = None,
-                 selected_predictands: List = None, var_tar2in: str = None, norm_dims: List = None, norm_obj=None,
+    def __init__(self, datadir, patt, nfiles_merge: int, selected_predictands: List, sample_dim: str = "time",
+                 selected_predictors: List = None, var_tar2in: str = None, norm_dims: List = None, norm_obj=None,
                  nworkers: int = 10):
         """
         Class object providing all methods to create a TF dataset that iterates over a set of (monthly) netCDF-files
         rather than loading all into memory. Instead, only a subset of all netCDF-files is loaded into memory.
-        Furthermore, the class attributes provide key information on the handled dataset.
+        Furthermore, the class attributes provide key information on the handled dataset
         :param datadir: directory where set of netCDF-files are located
         :param patt: filename pattern to allow globbing for netCDF-files
         :param nfiles_merge: number of files that will be loaded into memory (corresponding to one dataset subset)
-        :param sample_dim: name of dimension in the data over which sampling should be performed
-        :param selected_predictors: list of predictor variable names to be obtained
         :param selected_predictands: list of predictand variables names to be obtained
+        :param sample_dim: name of dimension in the data over which sampling should be performed
+        :param selected_predictors: list of predictor variable names to be obtained, pass None
+                                    if all vars with suffix _in should be chosen
         :param var_tar2in: predictand (target) variable that can be inputted as well
                           (e.g. static variables known a priori such as the surface topography)
         :param norm_dims: list of dimensions over which data will be normalized
@@ -399,6 +426,7 @@ class StreamMonthlyNetCDF(object):
         self.nfiles2merge = nfiles_merge
         self.nfiles_merged = int(self.nfiles / self.nfiles2merge)
         self.samples_merged = self.get_samples_per_merged_file()
+        self.varnames_list = self.get_all_varnames()
         print(f"Data subsets will comprise {self.samples_merged} samples.")
         self.predictor_list = selected_predictors
         self.predictand_list = selected_predictands
@@ -437,7 +465,12 @@ class StreamMonthlyNetCDF(object):
     def getitems(self, indices):
         da_now = self.data_now.isel({self.sample_dim: indices}).to_array("variables")
         if self.var_tar2in is not None:
-            da_now = xr.concat([da_now, da_now.sel({"variables": self.var_tar2in})], dim="variables")
+            # NOTE: * The concatenation order must match make_tf_dataset_allmem
+            #       * var_tar2in must be concatenated in front of da_now so that it appears
+            #         as the first channel. This avoids var_tar2in ending up as a
+            #         predictand when the slicing in tf_split takes place
+            da_now = xr.concat([da_now.sel({"variables": self.var_tar2in}), da_now], dim="variables")
+
         return da_now.transpose(..., "variables")
 
     def get_dataset_size(self):
@@ -529,8 +562,9 @@ class StreamMonthlyNetCDF(object):
     @predictor_list.setter
     def predictor_list(self, selected_predictors: List):
         """
-        Initalizes predictor list. In case that selected_predictors is set to None, all variables with suffix `_in` in their names are selected.
-        In case that a list of selected_predictors is parsed, their availability is checked.
+        Initializes the predictor list. If selected_predictors is set to None, all variables with suffix `_in`
+        in their names are selected.
+        If a list of selected_predictors is passed, their availability is checked
         :param selected_predictors: list of predictor variables or None
         """
         self._predictor_list = self.check_and_choose_vars(selected_predictors, "_in")
@@ -541,22 +575,28 @@ class StreamMonthlyNetCDF(object):
 
     @predictand_list.setter
     def predictand_list(self, selected_predictands: List):
+        """
+        Similar to the predictor_list-setter, but does not allow passing None.
+        """
+        assert isinstance(selected_predictands, list), "Selected predictands must be a list of variable names"
         self._predictand_list = self.check_and_choose_vars(selected_predictands, "_tar")
 
+    def get_all_varnames(self):
+        ds_test = xr.open_dataset(self.file_list[0])
+        return list(ds_test.variables)
+
     def check_and_choose_vars(self, var_list: List[str], suffix: str = "*"):
         """
-        Checks list of variables for availability or retrieves all variables named with a given suffix (for var_list = None)
+        Checks list of variables for availability or retrieves all variables named with a given suffix
+        (for var_list = None)
         :param var_list: list of predictor variables or None
-        :param suffix: optional suffix of variables to selected. Only effective if var_list is None.
+        :param suffix: optional suffix of variables to be selected. Only effective if var_list is None
         :return selected_vars: list of selected variables
         """
-        ds_test = xr.open_dataset(self.file_list[0])
-        all_vars = list(ds_test.variables)
-
         if var_list is None:
-            selected_vars = [var for var in all_vars if var.endswith(suffix)]
+            selected_vars = [var for var in self.varnames_list if var.endswith(suffix)]
         else:
-            stat_list = [var in all_vars for var in var_list]
+            stat_list = [var in self.varnames_list for var in var_list]
             if all(stat_list):
                 selected_vars = var_list
             else:
@@ -592,9 +632,9 @@ class StreamMonthlyNetCDF(object):
     def read_netcdf(self, set_ind):
         set_ind = tf.keras.backend.get_value(set_ind)
         set_ind = int(str(set_ind).lstrip("b'").rstrip("'"))
-        set_ind = int(set_ind%self.nfiles_merged)
+        set_ind = int(set_ind % self.nfiles_merged)
         file_list_now = self.file_list_random[set_ind * self.nfiles2merge:(set_ind + 1) * self.nfiles2merge]
-        il = int(self.iload_next%2)
+        il = int(self.iload_next % 2)
         # read the normalized data into memory
         # ds_now = xr.open_mfdataset(list(file_list_now), decode_cf=False, data_vars=self.all_vars,
         #                           preprocess=partial(self._preprocess_ds, data_norm=self.data_norm),
@@ -623,8 +663,6 @@ class StreamMonthlyNetCDF(object):
             # free memory
             free_mem([ds_add, add_samples, istart])
 
-        # free memory
-        free_mem([nsamples])
         self.data_loaded[il] = data_now
         # timing
         t_read = timer() - t0
@@ -632,6 +670,8 @@ class StreamMonthlyNetCDF(object):
         self.ds_proc_size += data_now.nbytes
         print(f"Dataset #{set_ind:d} ({il+1:d}/2) reading time: {t_read:.2f}s.")
         self.iload_next = il + 1
+        # free memory
+        free_mem([nsamples, t_read, data_now])
 
         return il
 
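The reordered concatenation of var_tar2in in make_tf_dataset_allmem and StreamMonthlyNetCDF.getitems only matters because tf_split later slices the trailing channels off as targets. A self-contained sketch with dummy data and illustrative variable names (not repository code) showing why prepending keeps the static variable on the input side:

    import numpy as np
    import xarray as xr

    vars_all = ["u_in", "v_in", "t_2m_tar", "hsurf_tar"]
    da = xr.DataArray(np.random.rand(4, 8, 8, len(vars_all)),
                      dims=("time", "y", "x", "variables"),
                      coords={"variables": vars_all})

    var_tar2in, n_predictands = "hsurf_tar", 2
    # prepend the static target variable, as done in getitems/make_tf_dataset_allmem
    da = xr.concat([da.sel(variables=[var_tar2in]), da], dim="variables")

    # tf_split-style slicing: the trailing n_predictands channels become the targets
    da_in, da_tar = da[..., :-n_predictands], da[..., -n_predictands:]
    print(list(da_in["variables"].values))   # ['hsurf_tar', 'u_in', 'v_in']
    print(list(da_tar["variables"].values))  # ['t_2m_tar', 'hsurf_tar']
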
diff --git a/downscaling_ap5/main_scripts/main_postprocess.py b/downscaling_ap5/main_scripts/main_postprocess.py
index 4afe089f7dc24bd6bc6777a720b377125e9d7456..0e60fbd7fbf19a05c6e0c88cb65bc4b66cceca43 100644
--- a/downscaling_ap5/main_scripts/main_postprocess.py
+++ b/downscaling_ap5/main_scripts/main_postprocess.py
@@ -42,8 +42,11 @@ def main(parser_args):
     model_dir, plt_dir, norm_dir, model_type = get_model_info(model_base, parser_args.output_base_dir,
                                                               parser_args.exp_name, parser_args.last,
                                                               parser_args.model_type)
-    # create logger handlers
+
+    # create output-directory and set name of netCDF-file to store inference data
     os.makedirs(plt_dir, exist_ok=True)
+    ncfile_out = os.path.join(plt_dir, "postprocessed_ds_test.nc")
+    # create logger handlers
     logfile = os.path.join(plt_dir, f"postprocessing_{parser_args.exp_name}.log")
     if os.path.isfile(logfile): os.remove(logfile)
     fh = logging.FileHandler(logfile)
@@ -98,7 +101,11 @@ def main(parser_args):
     norm = ZScore(ds_dict["norm_dims"])
     norm.read_norm_from_file(js_norm)
 
+    tar_varname = ds_dict["predictands"][0]
+    logger.info(f"Variable {tar_varname} serves as ground truth data.")
+
     with xr.open_dataset(fdata_test) as ds_test:
+        ground_truth = ds_test[tar_varname].astype("float32", copy=False)
         ds_test = norm.normalize(ds_test)
 
     # prepare training and validation data
@@ -107,14 +114,11 @@ def main(parser_args):
 
     da_test = HandleDataClass.reshape_ds(ds_test.astype("float32", copy=False))
     tfds_test = HandleDataClass.make_tf_dataset_allmem(da_test.astype("float32", copy=True), ds_dict["batch_size"],
-                                                       lshuffle=False, var_tar2in=ds_dict["var_tar2in"],
+                                                       ds_dict["predictands"], lshuffle=False, var_tar2in=ds_dict["var_tar2in"],
                                                        named_targets=named_targets, lrepeat=False, drop_remainder=False)
 
     # perform normalization
-    da_test_in, da_test_tar = HandleDataClass.split_in_tar(da_test)
-    tar_varname = da_test_tar['variables'].values[0]
-    ground_truth = ds_test[tar_varname].astype("float32", copy=False)
-    logger.info(f"Variable {tar_varname} serves as ground truth data.")
+    da_test_in, da_test_tar = HandleDataClass.split_in_tar(da_test, predictands=ds_dict["predictands"])
 
     # start inference
     logger.info(f"Preparation of test dataset finished after {timer() - t0_preproc:.2f}s. " +
@@ -135,14 +139,17 @@ def main(parser_args):
         # no slicing required
         y_pred = xr.DataArray(y_pred_trans.squeeze(), coords=coords, dims=dims)
     # perform denormalization
-    y_pred, ground_truth = norm.denormalize(y_pred.squeeze(), varname=tar_varname), norm.denormalize(ground_truth.squeeze(), varname=tar_varname)
+    y_pred = norm.denormalize(y_pred.squeeze(), varname=tar_varname)
+
+    # write inference data to netCDF
+    logger.info(f"Write inference data to netCDF-file '{ncfile_out}'")
+    ground_truth.name, y_pred.name = f"{tar_varname}_ref", f"{tar_varname}_fcst"
+    ds = xr.merge([y_pred.to_dataset(), ground_truth.to_dataset()])
+    ds.to_netcdf(ncfile_out)
 
     # start evaluation
     logger.info(f"Output data on test dataset successfully processed in {timer()-t0_train:.2f}s. Start evaluation...")
 
-    # create plot directory if required
-    os.makedirs(plt_dir, exist_ok=True)
-
     # instantiate score engine for time evaluation (i.e. hourly time series of evalutaion metrics)
     score_engine = Scores(y_pred, ground_truth, ds_dict["norm_dims"][1:])
 
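The postprocessing step now keeps the raw ground truth and writes it to netCDF alongside the denormalized prediction. A small sketch of that output step with dummy arrays and a hypothetical variable name (t_2m_tar):

    import numpy as np
    import xarray as xr

    coords = {"time": np.arange(3), "rlat": np.arange(4), "rlon": np.arange(5)}
    dims = ("time", "rlat", "rlon")
    y_pred = xr.DataArray(np.random.rand(3, 4, 5), coords=coords, dims=dims, name="t_2m_tar_fcst")
    ground_truth = xr.DataArray(np.random.rand(3, 4, 5), coords=coords, dims=dims, name="t_2m_tar_ref")

    # merge forecast and reference into one dataset and store it
    ds = xr.merge([y_pred.to_dataset(), ground_truth.to_dataset()])
    ds.to_netcdf("postprocessed_ds_test.nc")
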
diff --git a/downscaling_ap5/main_scripts/main_train.py b/downscaling_ap5/main_scripts/main_train.py
index 45c5bc52a0ffd813dba3cc5e98a959d699291576..66970190c0e00b826422ada7c7ae0ec09ff5b4a4 100644
--- a/downscaling_ap5/main_scripts/main_train.py
+++ b/downscaling_ap5/main_scripts/main_train.py
@@ -9,7 +9,7 @@ Driver-script to train downscaling models.
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2022-10-06"
-__update__ = "2023-03-10"
+__update__ = "2023-04-17"
 
 import os
 import argparse
@@ -20,10 +20,11 @@ import json as js
 from timeit import default_timer as timer
 import numpy as np
 import xarray as xr
+from tensorflow.keras.utils import plot_model
 from all_normalizations import ZScore
 from model_utils import ModelEngine, TimeHistory, handle_opt_utils, get_loss_from_history
 from handle_data_class import HandleDataClass, get_dataset_filename
-from other_utils import print_gpu_usage, print_cpu_usage, copy_filelist
+from other_utils import free_mem, print_gpu_usage, print_cpu_usage, copy_filelist
 from benchmark_utils import BenchmarkCSV, get_training_time_dict
 
 
@@ -31,8 +32,7 @@ from benchmark_utils import BenchmarkCSV, get_training_time_dict
 # * d_steps must be parsed with hparams_dict as model is uninstantiated at this point and thus no default parameters
 #   are available
 # * flag named_targets must be set to False in hparams_dict for WGAN to work with U-Net 
-# * ensure that dataset defaults are set
-# * customized choice on predictors and predictands missing
+# * ensure that dataset defaults are set (such as predictands for WGAN)
 # * replacement of manual benchmarking by JUBE-benchmarking to avoid duplications
 
 def main(parser_args):
@@ -82,18 +82,21 @@ def main(parser_args):
     # training data
     print("Start preparing training data...")
     t0_train = timer()
+    varnames_tar = list(ds_dict["predictands"])
     fname_or_patt_train = get_dataset_filename(datadir, dataset, "train", ds_dict.get("laugmented", False))
 
     # if fname_or_patt_train is a filename (string without wildcard), all data will be loaded into memory
     # if fname_or_patt_train is a filename pattern (string with wildcard), the TF-dataset will iterate over subsets of
     # the dataset
     if "*" in fname_or_patt_train:
-        ds_obj, tfds_train = HandleDataClass.make_tf_dataset_dyn(datadir, fname_or_patt_train, bs_train, nepochs, 30,
-                                                                 var_tar2in=ds_dict["var_tar2in"], norm_obj=data_norm,
-                                                                 norm_dims=norm_dims)
+        ds_obj, tfds_train = HandleDataClass.make_tf_dataset_dyn(datadir, fname_or_patt_train, bs_train, nepochs,
+                                                                 30, ds_dict["predictands"],
+                                                                 predictors=ds_dict.get("predictors", None),
+                                                                 var_tar2in=ds_dict["var_tar2in"],
+                                                                 named_targets=named_targets,
+                                                                 norm_obj=data_norm, norm_dims=norm_dims)
         data_norm = ds_obj.data_norm
         nsamples, shape_in = ds_obj.nsamples, (*ds_obj.data_dim[::-1], ds_obj.n_predictors)
-        varnames_tar = list(ds_obj.predictand_list) if named_targets else None
         tfds_train_size = ds_obj.dataset_size
     else:
         ds_train = xr.open_dataset(fname_or_patt_train)
@@ -104,10 +107,11 @@ def main(parser_args):
             data_norm = ZScore(ds_dict["norm_dims"])
 
         da_train = data_norm.normalize(da_train)
-        tfds_train = HandleDataClass.make_tf_dataset_allmem(da_train, bs_train, var_tar2in=ds_dict["var_tar2in"],
+        tfds_train = HandleDataClass.make_tf_dataset_allmem(da_train, bs_train, ds_dict["predictands"],
+                                                            predictors=ds_dict.get("predictors", None),
+                                                            var_tar2in=ds_dict["var_tar2in"],
                                                             named_targets=named_targets)
         nsamples, shape_in = da_train.shape[0], tfds_train.element_spec[0].shape[1:].as_list()
-        varnames_tar = list(tfds_train.element_spec[1].keys()) if named_targets else None
         tfds_train_size = da_train.nbytes
 
     if write_norm:
@@ -123,12 +127,13 @@ def main(parser_args):
         ds_val = data_norm.normalize(ds_val)
     da_val = HandleDataClass.reshape_ds(ds_val)
 
-    tfds_val = HandleDataClass.make_tf_dataset_allmem(da_val.astype("float32", copy=True), ds_dict["batch_size"], lshuffle=True,
-                                                      var_tar2in=ds_dict["var_tar2in"], named_targets=named_targets)
+    tfds_val = HandleDataClass.make_tf_dataset_allmem(da_val.astype("float32", copy=True), ds_dict["batch_size"],
+                                                      ds_dict["predictands"], predictors=ds_dict.get("predictors", None),
+                                                      lshuffle=True, var_tar2in=ds_dict["var_tar2in"],
+                                                      named_targets=named_targets)
     
     # clean up to save some memory
-    del ds_val
-    gc.collect()
+    free_mem([ds_val, da_val])
 
     tval_load = timer() - t0_val
     print(f"Validation data preparation time: {tval_load:.2f}s.")
@@ -142,13 +147,18 @@ def main(parser_args):
         print(f"Data loading time: {ttrain_load:.2f}s.")
 
     # instantiate model
-    model = model_instance(shape_in, hparams_dict, model_savedir, parser_args.exp_name)
-    model.varnames_tar = varnames_tar
+    model = model_instance(shape_in, varnames_tar, hparams_dict, model_savedir, parser_args.exp_name)
 
     # get optional compile options and compile
     compile_opts = handle_opt_utils(model, "get_compile_opts")
     model.compile(**compile_opts)
 
+    # copy configuration and normalization JSON-file to model-directory (incl. renaming)
+    filelist, filelist_new = [parser_args.conf_ds.name, parser_args.conf_md.name], [f"config_ds_{dataset}.json", f"config_{parser_args.model}.json"]
+    if not write_norm:
+        filelist.append(js_norm), filelist_new.append(os.path.basename(js_norm))
+    copy_filelist(filelist, model_savedir, filelist_new)
+
     # train model
     time_tracker = TimeHistory()
     steps_per_epoch = int(np.ceil(nsamples / ds_dict["batch_size"]))
@@ -177,6 +187,12 @@ def main(parser_args):
     os.makedirs(model_savedir, exist_ok=True)
     model.save(filepath=model_savedir)
 
+    if callable(getattr(model, "plot_model", False)):
+        model.plot_model(model_savedir, show_shapes=True)
+    else:
+        plot_model(model, os.path.join(model_savedir, f"plot_{parser_args.exp_name}.png"),
+                   show_shapes=True)
+
     # final timing
     tend = timer()
     saving_time = tend - t0_save
@@ -219,11 +235,6 @@ def main(parser_args):
 
     print("Finished job at {0}".format(dt.strftime(dt.now(), "%Y-%m-%d %H:%M:%S")))
 
-    # copy configuration and normalization JSON-file to output-directory
-    filelist = [parser_args.conf_ds.name, parser_args.conf_md.name]
-    if not write_norm: filelist.append(js_norm)
-    copy_filelist(filelist, model_savedir)
-
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
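Model plotting now prefers a model-specific plot_model method (as added for WGAN below) and falls back to Keras' plot_model utility otherwise. The dispatch pattern in isolation, with a hypothetical class standing in for the model:

    class DummyModel:
        def plot_model(self, save_dir, **kwargs):
            print(f"custom plotting into {save_dir}")

    model = DummyModel()
    # same check as in main_train.py: use the model's own method if it exists
    if callable(getattr(model, "plot_model", False)):
        model.plot_model("trained_models/exp01", show_shapes=True)
    else:
        print("fall back to tf.keras.utils.plot_model")
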
diff --git a/downscaling_ap5/models/model_utils.py b/downscaling_ap5/models/model_utils.py
index d4998e71307fe06c1f3f96336941afdd3fbb4551..5011e9cdd2eaeab3d6bdf965ed80d6c69b86d325 100644
--- a/downscaling_ap5/models/model_utils.py
+++ b/downscaling_ap5/models/model_utils.py
@@ -50,13 +50,14 @@ class ModelEngine(object):
         else:
             self.model = self.known_models[self.modelname]
 
-    def __call__(self, shape_in, hparams_dict, save_dir, exp_name, **kwargs):
+    def __call__(self, shape_in, varnames_tar, hparams_dict, save_dir, exp_name, **kwargs):
         """
         Instantiate the model with some required arguments.
         """
         model_list = to_list(self.model)
         target_model = model_list[0]
-        model_args = {"shape_in": shape_in, "hparams": hparams_dict, "exp_name": exp_name, "savedir": save_dir, **kwargs}
+        model_args = {"shape_in": shape_in, "varnames_tar": varnames_tar, "hparams": hparams_dict,
+                      "exp_name": exp_name, "savedir": save_dir, **kwargs}
 
         try:
             if len(model_list) == 1:
diff --git a/downscaling_ap5/models/unet_model.py b/downscaling_ap5/models/unet_model.py
index a96e3ad728a354c09008d34d61ea7152c10f0577..60ee0ffdc30929bd822cd352fe5f747e8c002655 100644
--- a/downscaling_ap5/models/unet_model.py
+++ b/downscaling_ap5/models/unet_model.py
@@ -9,7 +9,7 @@ Methods to set-up U-net models incl. its building blocks.
 __author__ = "Michael Langguth"
 __email__ = "m.langguth@fz-juelich.de"
 __date__ = "2021-XX-XX"
-__update__ = "2022-11-25"
+__update__ = "2023-05-06"
 
 # import modules
 import os
@@ -115,12 +115,13 @@ def decoder_block(inputs, skip_features, num_filters, kernel: tuple = (3, 3), st
 
 
 # The particular U-net
-def sha_unet(input_shape: tuple, channels_start: int = 56, z_branch: bool = False,
+def sha_unet(input_shape: tuple, n_predictands_dyn: int, channels_start: int = 56, z_branch: bool = False,
              concat_out: bool = False, tar_channels=["output_dyn", "output_z"]) -> Model:
     """
     Builds up U-net model architecture adapted from Sha et al., 2020 (see https://doi.org/10.1175/JAMC-D-20-0057.1).
     :param input_shape: shape of input-data
     :param channels_start: number of channels to use as start in encoder blocks
+    :param n_predictands_dyn: number of dynamic target variables (dynamic output channels)
     :param z_branch: flag if z-branch is used.
     :param tar_channels: name of output/target channels (needed for associating losses during compilation)
     :param concat_out: boolean if output layers will be concatenated (disables named target channels!)
@@ -141,7 +142,7 @@ def sha_unet(input_shape: tuple, channels_start: int = 56, z_branch: bool = Fals
     d2 = decoder_block(d1, s2, channels_start * 2)
     d3 = decoder_block(d2, s1, channels_start)
 
-    output_dyn = Conv2D(1, (1, 1), kernel_initializer="he_normal", name=tar_channels[0])(d3)
+    output_dyn = Conv2D(n_predictands_dyn, (1, 1), kernel_initializer="he_normal", name=tar_channels[0])(d3)
     if z_branch:
         print("Use z_branch...")
         output_static = Conv2D(1, (1, 1), kernel_initializer="he_normal", name=tar_channels[1])(d3)
@@ -161,15 +162,17 @@ class UNET(keras.Model):
     U-Net submodel class:
     This subclass takes a U-Net implemented using Keras functional API as input to the instanciation.
     """
-    def __init__(self, unet_model: keras.Model, shape_in: List, hparams: dict, savedir: str,
+    def __init__(self, unet_model: keras.Model, shape_in: List, varnames_tar: List, hparams: dict, savedir: str,
                  exp_name: str = "unet_model"):
 
         super(UNET, self).__init__()
 
         self.unet = unet_model
         self.shape_in = shape_in
-        self.varnames_tar = None                            # yet, dirty approach: to be set after instantiating from main_train.py
+        self.varnames_tar = varnames_tar
         self.hparams = UNET.get_hparams_dict(hparams)
+        self.n_predictands = len(varnames_tar)                      # number of predictands
+        self.n_predictands_dyn = self.n_predictands - 1 if self.hparams["z_branch"] else self.n_predictands
         if self.hparams["l_embed"]:
             raise ValueError("Embedding is not implemented yet.")
         self.modelname = exp_name
@@ -183,6 +186,7 @@ class UNET(keras.Model):
         See https://stackoverflow.com/questions/65318036/is-it-possible-to-use-the-tensorflow-keras-functional-api-within-a-subclassed-model
         for a reference how a model based on Keras functional API has to be integrated into a subclass.
         """
         return self.unet(inputs, **kwargs)
 
     def get_compile_opts(self):
@@ -207,11 +211,12 @@ class UNET(keras.Model):
 
     def compile(self, **kwargs):
         # instantiate model
-        if self.varnames_tar:
+        if self.hparams["z_branch"]:     # model has named branches (see also opt_dict in get_compile_opts)
             tar_channels = [f"{var}" for var in self.varnames_tar]
-            self.unet = self.unet(self.shape_in, z_branch=self.hparams["z_branch"], tar_channels=tar_channels)
+            self.unet = self.unet(self.shape_in, self.n_predictands_dyn, z_branch=True,
+                                  concat_out=False, tar_channels=tar_channels)
         else:
-            self.unet = self.unet(self.shape_in, z_branch=self.hparams["z_branch"], concat_out=True)
+            self.unet = self.unet(self.shape_in, self.n_predictands_dyn, z_branch=False)
 
         return self.unet.compile(**kwargs)
        # return super(UNET, self).compile(**kwargs)
@@ -310,7 +315,7 @@ class UNET(keras.Model):
         hparams_dict = {"batch_size": 32, "lr": 5.e-05, "nepochs": 70, "z_branch": True, "loss_func": "mae",
                         "loss_weights": [1.0, 1.0], "lr_decay": False, "decay_start": 5, "decay_end": 30,
                         "lr_end": 1.e-06, "l_embed": False, "ngf": 56, "optimizer": "adam", "lscheduled_train": True,
-                        "var_tar2in": ""}
+                        "var_tar2in": "", "n_predictands": 1}
 
         return hparams_dict
 
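The dynamic output head of the U-Net is now sized by the predictand list rather than hard-coded to one channel. A tiny illustration (hypothetical values) of how n_predictands_dyn is derived in UNET.__init__ and WGAN.__init__:

    varnames_tar = ["t_2m_tar", "hsurf_tar"]   # e.g. the predictands from config_ds_tier2.json
    z_branch = True                            # static surface topography as extra output head

    n_predictands = len(varnames_tar)
    n_predictands_dyn = n_predictands - 1 if z_branch else n_predictands
    print(n_predictands, n_predictands_dyn)    # 2 1
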
diff --git a/downscaling_ap5/models/wgan_model.py b/downscaling_ap5/models/wgan_model.py
index f26f221176132624d087952dca4ff00825bef1df..1e14df96c255b6fb96798099b3f2ba0176ad3eec 100644
--- a/downscaling_ap5/models/wgan_model.py
+++ b/downscaling_ap5/models/wgan_model.py
@@ -15,6 +15,7 @@ import tensorflow as tf
 import tensorflow.keras as keras
 from tensorflow.keras import backend as K
 from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint, EarlyStopping
+from tensorflow.keras.utils import plot_model as k_plot_model
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.platform import tf_logging as logging
 
@@ -69,13 +70,14 @@ class WGAN(keras.Model):
     Class for Wassterstein GAN models
     """
 
-    def __init__(self, generator: keras.Model, critic: keras.Model, shape_in: List, hparams: dict, savedir: str,
-                 exp_name: str = "wgan_model"):
+    def __init__(self, generator: keras.Model, critic: keras.Model, shape_in: List, varnames_tar: List, hparams: dict,
+                 savedir: str, exp_name: str = "wgan_model"):
         """
         Initiate Wasserstein GAN model
         :param generator: A generator model returning a data field
         :param critic: A critic model which returns a critic scalar on the data field
         :param shape_in: shape of input data
+        :param varnames_tar: list of target variables/predictands (incl. static variables, e.g. for z_branch)
         :param hparams: dictionary of hyperparameters
         :param exp_name: name of the WGAN experiment
         :param savedir: directory where checkpointed model will be saved
@@ -88,16 +90,19 @@ class WGAN(keras.Model):
         self.hparams = WGAN.get_hparams_dict(hparams)
         if self.hparams["l_embed"]:
             raise ValueError("Embedding is not implemented yet.")
+        self.n_predictands = len(varnames_tar)                      # number of predictands
+        self.n_predictands_dyn = self.n_predictands - 1 if self.hparams["z_branch"] else self.n_predictands
+        print(f"Dynamic predictands: {self.n_predictands_dyn}, Predictands: {self.n_predictands}")
         self.modelname = exp_name
         if not os.path.isdir(savedir):
             os.makedirs(savedir, exist_ok=True)
         self.savedir = savedir
 
         # instantiate submodels
-        tar_shape = (*self.shape_in[:-1], 1)   # critic only accounts for 1st channel (should be the downscaling target)
-        # instantiate models
-        self.generator = self.generator(self.shape_in, channels_start=self.hparams["ngf"],
-                                        z_branch=self.hparams["z_branch"], concat_out=True)
+        # instantiate model components (generator and critic)
+        self.generator = self.generator(self.shape_in, self.n_predictands, channels_start=self.hparams["ngf"],
+                                        concat_out=True, z_branch=self.hparams["z_branch"])
+        tar_shape = (*self.shape_in[:-1], self.n_predictands_dyn)   # critic only accounts for dynamic predictands
         self.critic = self.critic(tar_shape)
 
         # Unused attribute, but introduced for joint driver script with U-Net; to be solved with customized target vars
@@ -190,16 +195,19 @@ class WGAN(keras.Model):
         for i in range(self.hparams["d_steps"]):
             with tf.GradientTape() as tape_critic:
                 ist, ie = i * self.hparams["batch_size"], (i + 1) * self.hparams["batch_size"]
-                # critic only operates on first channel
-                predictands_critic = tf.expand_dims(predictands[ist:ie, :, :, 0], axis=-1)
+                # critic only operates on the dynamic predictand channels
+                if self.n_predictands_dyn > 1:
+                    predictands_critic = predictands[ist:ie, :, :, 0:self.n_predictands_dyn]
+                else:
+                    predictands_critic = tf.expand_dims(predictands[ist:ie, :, :, 0], axis=-1)
                 # generate (downscaled) data
                 gen_data = self.generator(predictors[ist:ie, ...], training=True)
                 # calculate critics for both, the real and the generated data
-                critic_gen = self.critic(gen_data[..., 0], training=True)
+                critic_gen = self.critic(gen_data[..., 0:self.n_predictands_dyn], training=True)
                 critic_gt = self.critic(predictands_critic, training=True)
                 # calculate the loss (incl. gradient penalty)
                 c_loss = WGAN.critic_loss(critic_gt, critic_gen)
-                gp = self.gradient_penalty(predictands_critic, gen_data[..., 0:1])
+                gp = self.gradient_penalty(predictands_critic, gen_data[..., 0:self.n_predictands_dyn])
 
                 d_loss = c_loss + self.hparams["gp_weight"] * gp
 
@@ -212,7 +220,7 @@ class WGAN(keras.Model):
             # generate (downscaled) data
             gen_data = self.generator(predictors[-self.hparams["batch_size"]:, :, :, :], training=True)
             # get the critic and calculate corresponding generator losses (critic and reconstruction loss)
-            critic_gen = self.critic(gen_data[..., 0], training=True)
+            critic_gen = self.critic(gen_data[..., 0:self.n_predictands_dyn], training=True)
             cg_loss = WGAN.critic_gen_loss(critic_gen)
             rloss = self.recon_loss(predictands[-self.hparams["batch_size"]:, :, :, :], gen_data)
 
@@ -252,7 +260,7 @@ class WGAN(keras.Model):
         :return: gradient penalty
         """
         # get mixture of generated and ground truth data
-        alpha = tf.random.normal([self.hparams["batch_size"], 1, 1, 1], 0., 1.)
+        alpha = tf.random.normal([self.hparams["batch_size"], 1, 1, self.n_predictands_dyn], 0., 1.)
         mix_data = real_data + alpha * (gen_data - real_data)
 
         with tf.GradientTape() as gp_tape:
@@ -270,16 +278,21 @@ class WGAN(keras.Model):
     def recon_loss(self, real_data, gen_data):
         # initialize reconstruction loss
         rloss = 0.
-        # get number of output heads (=2 if z_branch is activated)
-        n = 1
-        if self.hparams["z_branch"]:
-            n = 2
         # get MAE for all output heads
-        for i in range(n):
+        for i in range(self.hparams["n_predictands"]):
             rloss += tf.reduce_mean(tf.abs(tf.squeeze(gen_data[..., i]) - real_data[..., i]))
 
         return rloss
 
+    def plot_model(self, save_dir, **kwargs):
+        """
+        Plot generator and critic model separately.
+        :param save_dir: directory under which plots will be saved
+        :param kwargs: All keyword arguments valid for tf.keras.utils.plot_model
+        """
+        k_plot_model(self.generator, os.path.join(save_dir, f"plot_{self.modelname}_generator.png"), **kwargs)
+        k_plot_model(self.critic, os.path.join(save_dir, f"plot_{self.modelname}_critic.png"), **kwargs)
+
     def save(self, filepath: str, overwrite: bool = True, include_optimizer: bool = True, save_format: str = None,
              signatures=None, options=None, save_traces: bool = True):
         """
@@ -362,7 +375,7 @@ class WGAN(keras.Model):
         hparams_dict = {"batch_size": 32, "lr_gen": 1.e-05, "lr_critic": 1.e-06, "nepochs": 50, "z_branch": False,
                         "lr_decay": False, "decay_start": 5, "decay_end": 10, "lr_gen_end": 1.e-06, "l_embed": False,
                         "ngf": 56, "d_steps": 5, "recon_weight": 1000., "gp_weight": 10., "optimizer": "adam", 
-                        "lscheduled_train": True, "var_tar2in": ""}
+                        "lscheduled_train": True, "var_tar2in": "", "n_predictands": 2}
 
         return hparams_dict
 
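The critic and the gradient penalty now operate on all dynamic predictand channels instead of channel 0 only. A shape-only sketch with hypothetical dimensions (plain NumPy, not repository code):

    import numpy as np

    batch, ny, nx = 8, 96, 120
    n_predictands, n_predictands_dyn = 3, 2           # e.g. u_10m, v_10m plus hsurf via z_branch

    gen_data = np.random.rand(batch, ny, nx, n_predictands).astype("float32")
    critic_in = gen_data[..., 0:n_predictands_dyn]    # replaces the old gen_data[..., 0]
    print(critic_in.shape)                            # (8, 96, 120, 2)
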
diff --git a/downscaling_ap5/postprocess/postprocess.py b/downscaling_ap5/postprocess/postprocess.py
index 1f33315df118a86af976a4a26a46454ec3f0b7d0..bac9c6ab6ce2ca33c88344537ed499a5d8fb883c 100644
--- a/downscaling_ap5/postprocess/postprocess.py
+++ b/downscaling_ap5/postprocess/postprocess.py
@@ -36,11 +36,11 @@ def get_model_info(model_base, output_base: str, exp_name: str, bool_last: bool
                              os.path.join(output_base, model_name)
         norm_dir = model_base
         model_type = "wgan"
-    elif "unet" in exp_name:
+    elif "unet" in exp_name or "deepru" in exp_name:
         func_logger.debug(f"U-Net-modeltype detected.")
         model_dir, plt_dir = model_base, os.path.join(output_base, model_name)
         norm_dir = model_dir
-        model_type = "unet"
+        model_type = "unet" if "unet" in exp_name else "deepru"
     else:
         func_logger.debug(f"Model type could not be inferred from experiment name. Try my best by defaulting...")
         if not model_type:
diff --git a/downscaling_ap5/test_scripts/demo_tfdataset_ap5.py b/downscaling_ap5/test_scripts/demo_tfdataset_ap5.py
index 3c073a53d6c5f6cc71b3e4863f8d30afe42d264b..f01e84b2225229096c6339be7eed67354a6e904b 100644
--- a/downscaling_ap5/test_scripts/demo_tfdataset_ap5.py
+++ b/downscaling_ap5/test_scripts/demo_tfdataset_ap5.py
@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 # SPDX-FileCopyrightText: 2023 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
 #
 # SPDX-License-Identifier: MIT
@@ -74,381 +73,6 @@ def main():
     # some statistics on memory usage
     print_gpu_usage("Final GPU memory: ")
     print_cpu_usage("Final CPU memory: ")
-
-
-if __name__ == "__main__":
-    main()
-
-=======
-import os, sys, glob
-import argparse
-from typing import List, Union
-from operator import itemgetter
-from functools import partial
-import re
-import gc
-import random
-from timeit import default_timer as timer
-import numpy as np
-import xarray as xr
-import dask
-from multiprocessing import Pool as ThreadPool
-import tensorflow as tf
-import tensorflow.keras as keras
-
-da_or_ds = Union[xr.DataArray, xr.Dataset]
-
-def main():
-    parser = argparse.ArgumentParser("Program that test the MAELSTROM AP5 data pipeline")
-    parser.add_argument("--datadir", "-d", dest="datadir", type=str,
-                       default="/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files_copy/", 
-                       help="Directory where monthly netCDF-files are stored")
-    parser.add_argument("--file_pattern", "-f", dest="file_patt", type=str, default="downscaling_tier2_train_*.nc", help="Filename pattern to glob netCDF-files")    
-    parser.add_argument("--nfiles_load", "-n", default=30, type=int, dest="nfiles_load",
-                        help="Number of netCDF-files to load into memory (2x with prefetching).")
-    parser.add_argument("--lprefetch", "-lprefetch", dest="lprefetch", action="store_true",
-                       help="Enable prefetching.")
-    parser.add_argument("--batch_size", "-b", dest="batch_size", default=192, type=int, 
-                        help="Batch size for TF dataset.")
-    args = parser.parse_args()
-    
-    # get data handler
-    ds_obj = StreamMonthlyNetCDF(args.datadir, args.file_patt, nfiles_merge=args.nfiles_load,
-                                 norm_dims=["time", "rlat", "rlon"])
-    
-    # set-up TF dataset
-    ## map-funcs
-    tf_read_nc = lambda ind_set: tf.py_function(ds_obj.read_netcdf, [ind_set], tf.int64)
-    tf_choose_data = lambda il: tf.py_function(ds_obj.choose_data, [il], tf.bool)
-    tf_getdata = lambda i: tf.numpy_function(ds_obj.getitems, [i], tf.float32)
-    tf_split = lambda arr: (arr[..., 0:-ds_obj.n_predictands], arr[..., -ds_obj.n_predictands:])
-    
-    ## process chain
-    tfds = tf.data.Dataset.range(int(ds_obj.nfiles_merged*6*10)).map(tf_read_nc) # 6*10 corresponds to (d_steps + 1)*n_epochs with d_steps=5, n_epochs=10
-    if args.lprefetch:
-        tfds = tfds.prefetch(1)     # .prefetch(1) ensures that one data subset (=n files) is prefetched
-    tfds = tfds.flat_map(lambda x: tf.data.Dataset.from_tensors(x).map(tf_choose_data))
-    tfds = tfds.flat_map(
-        lambda x: tf.data.Dataset.range(ds_obj.samples_merged).shuffle(ds_obj.samples_merged)
-        .batch(args.batch_size, drop_remainder=True).map(tf_getdata))
-
-    tfds = tfds.prefetch(int(ds_obj.samples_merged))
-    tfds = tfds.map(tf_split).repeat()
-    
-    t0 = timer()
-    for i, x in enumerate(tfds):
-        if i == int(ds_obj.nsamples/args.batch_size) - 1:
-            break
-            
-    print(f"Processing one epoch of data lasted {timer() - t0:.1f} seconds.")
-    
-
-class ZScore(object):
-    """
-    Class for performing zscore-normalization and denormalization.
-    Also computes normalization parameters from data if necessary.
-    """
-    def __init__(self, norm_dims: List):
-        self.method = "zscore"
-        self.norm_dims = norm_dims
-        self.norm_stats = {"mu": None, "sigma": None}
-
-    def get_required_stats(self, data: da_or_ds, **stats):
-        """
-        Get required parameters for z-score normalization. They are either computed from the data
-        or can be parsed as keyword arguments.
-        :param data: the data to be (de-)normalized
-        :param stats: keyword arguments for mean (mu) and standard deviation (std) used for normalization
-        :return (mu, sigma): Parameters for normalization
-        """
-        mu, std = stats.get("mu", self.norm_stats["mu"]), stats.get("sigma", self.norm_stats["sigma"])
-
-        if mu is None or std is None:
-            print("Retrieve mu and sigma from data...")
-            mu, std = data.mean(self.norm_dims), data.std(self.norm_dims)
-            # The following ensure that both parameters are computed in one graph!
-            # This significantly reduces memory footprint as we don't end up having data duplicates
-            # in memory due to multiple graphs (and also seem to enfore usage of data chunks as well)
-            mu, std = dask.compute(mu, std)
-            self.norm_stats = {"mu": mu, "sigma": std}
-        # else:
-        #    print("Mu and sigma are parsed for (de-)normalization.")
-
-        return mu, std
-
-    def normalize(self, data, **stats):
-        """
-        Perform z-score normalization on data. 
-        Either computes the normalization parameters from the data or applies pre-existing ones.
-        :param data: Data array of interest
-        :param mu: mean of data for normalization
-        :param std: standard deviation of data for normalization
-        :return data_norm: normalized data
-        """
-        mu, std = self.get_required_stats(data, **stats)
-        data_norm = (data - mu) / std
-
-        return data_norm
-
-    def denormalize(self, data, **stats):
-        """
-        Perform z-score denormalization on data.
-        :param data: Data array of interest
-        :param stats: optional keyword arguments mu (mean) and sigma (standard deviation) used for denormalization
-        :return data_denorm: denormalized data
-        """
-        if self.norm_stats["mu"] is None or self.norm_stats["sigma"] is None:
-            raise ValueError("Normalization parameters mu and sigma are unknown.")
-
-        mu, std = self.get_required_stats(data, **stats)
-        data_denorm = data * std + mu
-
-        return data_denorm
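
A minimal usage sketch of the ZScore class above on a small synthetic xarray.Dataset; the variable name t2m_in, the array shape and the round-trip check are illustrative assumptions, and the imports of the surrounding script (xarray, dask) are expected to be in scope.

import numpy as np
import xarray as xr

# synthetic dataset following the naming/dimension conventions of this script (illustrative only)
ds = xr.Dataset({"t2m_in": (("time", "rlat", "rlon"),
                            np.random.rand(8, 4, 4).astype("float32"))})

norm = ZScore(norm_dims=["time", "rlat", "rlon"])
ds_norm = norm.normalize(ds)         # computes mu/sigma once and caches them in norm.norm_stats
ds_back = norm.denormalize(ds_norm)  # inverts the transformation with the cached parameters

# round-trip error should be on the order of float32 round-off
print(float(abs(ds["t2m_in"] - ds_back["t2m_in"]).max()))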
-
-
-class StreamMonthlyNetCDF(object):
-    """
-    Data handler for monthly netCDF-files which provides methods for setting up
-    a TF dataset from data that is too large to fit into memory.
-    """
-    def __init__(self, datadir, patt: str, nfiles_merge: int, sample_dim: str = "time", selected_predictors: List = None,
-                 selected_predictands: List = None, var_tar2in: str = None, norm_dims: List = None, norm_obj=None):
-        """
-        Initialize data handler.
-        :param datadir: path to directory where netCDF-files are located
-        :param patt: file name pattern for globbing
-        :param nfiles_merge: number of netCDF-files that get loaded into memory 
-        :param sample_dim: dimension from which samples will be drawn
-        :param selected_predictors: list of predictor variable names
-        :param selected_predictands: list of predictand variable names
-        :param var_tar2in: target variable that should be added to input as well (e.g. surface topography)
-        :param norm_dims: dimensions over which data will be normalized
-        :param norm_obj: normalization object 
-        """
-        self.data_dir = datadir
-        self.file_list = patt
-        self.nfiles = len(self.file_list)
-        self.file_list_random = random.sample(self.file_list, self.nfiles)
-        self.nfiles2merge = nfiles_merge
-        self.nfiles_merged = int(self.nfiles / self.nfiles2merge)
-        self.samples_merged = self.get_samples_per_merged_file()
-        self.predictor_list = selected_predictors
-        self.predictand_list = selected_predictands
-        self.n_predictands, self.n_predictors = len(self.predictand_list), len(self.predictor_list)
-        self.all_vars = self.predictor_list + self.predictand_list
-        self.ds_all = xr.open_mfdataset(list(self.file_list), decode_cf=False, cache=False)  # , parallel=True)
-        self.var_tar2in = var_tar2in
-        if self.var_tar2in is not None:
-            self.n_predictors += len(to_list(self.var_tar2in))
-        self.sample_dim = sample_dim
-        self.nsamples = self.ds_all.dims[sample_dim]
-        self.data_dim = self.get_data_dim()
-        t0 = timer()
-        self.normalization_time = -999.
-        if norm_obj is None:  
-            print("Start computing normalization parameters.")
-            self.data_norm = ZScore(norm_dims)  # TO-DO: Allow for arbitrary normalization
-            self.norm_params = self.data_norm.get_required_stats(self.ds_all)
-            self.normalization_time = timer() - t0
-        else:
-            self.data_norm = norm_obj
-            self.norm_params = norm_obj.norm_stats
-
-        self.data_loaded = [xr.Dataset, xr.Dataset]
-        self.i_loaded = 0
-        self.data_now = None
-        self.pool = ThreadPool(10)        # To-Do: remove hard-coded number of threads (-> support contact)
-
-    def __len__(self):
-        return self.nsamples
-
-    def getitems(self, indices):
-        da_now = self.data_now.isel({self.sample_dim: indices}).to_array("variables")
-        if self.var_tar2in is not None:
-            da_now = xr.concat([da_now, da_now.sel({"variables": self.var_tar2in})], dim="variables")
-        return da_now.transpose(..., "variables")
-
-    def get_data_dim(self):
-        """
-        Retrieve the dimensionality of the data to be handled, i.e. without sample_dim which will be batched in a
-        data stream.
-        :return: tuple of data dimensions
-        """
-        # get existing dimension names and remove sample_dim
-        dimnames = list(self.ds_all.coords)
-        dimnames.remove(self.sample_dim)
-
-        # get the dimensionality of the data of interest
-        all_dims = dict(self.ds_all.dims)
-        data_dim = itemgetter(*dimnames)(all_dims)
-
-        return data_dim
-
-    def get_samples_per_merged_file(self):
-        nsamples_merged = []
-
-        for i in range(self.nfiles_merged):
-            file_list_now = self.file_list_random[i * self.nfiles2merge: (i + 1) * self.nfiles2merge]
-            ds_now = xr.open_mfdataset(list(file_list_now), decode_cf=False)
-            nsamples_merged.append(ds_now.dims["time"])  # To-Do avoid hard-coding
-
-        return max(nsamples_merged)
-
-    @property
-    def data_dir(self):
-        return self._data_dir
-
-    @data_dir.setter
-    def data_dir(self, datadir):
-        if not os.path.isdir(datadir):
-            raise NotADirectoryError(f"Parsed data directory '{datadir}' does not exist.")
-
-        self._data_dir = datadir
-
-    @property
-    def file_list(self):
-        return self._file_list
-
-    @file_list.setter
-    def file_list(self, patt):
-        patt = patt if patt.endswith(".nc") else f"{patt}.nc"
-        files = glob.glob(os.path.join(self.data_dir, patt))
-
-        if not files:
-            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")
-
-        self._file_list = list(
-            np.asarray(sorted(files, key=lambda s: int(re.search(r'\d+', os.path.basename(s)).group()))))
-
-    @property
-    def nfiles2merge(self):
-        return self._nfiles2merge
-
-    @nfiles2merge.setter
-    def nfiles2merge(self, n2merge):
-        #n = find_closest_divisor(self.nfiles, n2merge)
-        #if n != n2merge:
-        #    print(f"{n2merge} is not a divisor of the total number of files. Value is changed to {n}")
-        if self.nfiles % n2merge != 0:
-            raise ValueError(f"Chosen number of files to read ({n2merge:d}) must be a divisor "
-                             f"of the total number of files ({self.nfiles:d}).")
-        
-        self._nfiles2merge = n2merge
-
-    @property
-    def sample_dim(self):
-        return self._sample_dim
-
-    @sample_dim.setter
-    def sample_dim(self, sample_dim):
-        if sample_dim not in self.ds_all.dims:
-            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")
-
-        self._sample_dim = sample_dim
-
-    @property
-    def predictor_list(self):
-        return self._predictor_list
-
-    @predictor_list.setter
-    def predictor_list(self, selected_predictors: List):
-        """
-        Initializes the predictor list. If selected_predictors is None, all variables whose names carry the suffix `_in` are selected.
-        If a list of selected_predictors is parsed, their availability in the dataset is checked.
-        :param selected_predictors: list of predictor variables or None
-        """
-        self._predictor_list = self.check_and_choose_vars(selected_predictors, "_in")
-
-    @property
-    def predictand_list(self):
-        return self._predictand_list
-
-    @predictand_list.setter
-    def predictand_list(self, selected_predictands: List):
-        self._predictand_list = self.check_and_choose_vars(selected_predictands, "_tar")
-
-    def check_and_choose_vars(self, var_list: List[str], suffix: str = "*"):
-        """
-        Checks list of variables for availability or retrieves all variables named with a given suffix (for var_list = None)
-        :param var_list: list of predictor variables or None
-        :param suffix: optional suffix of the variables to be selected. Only effective if var_list is None.
-        :return selected_vars: list of selected variables
-        """
-        ds_test = xr.open_dataset(self.file_list[0])
-        all_vars = list(ds_test.variables)
-
-        if var_list is None:
-            selected_vars = [var for var in all_vars if var.endswith(suffix)]
-        else:
-            stat_list = [var in all_vars for var in var_list]
-            if all(stat_list):
-                selected_vars = var_list
-            else:
-                miss_inds = [i for i, x in enumerate(stat_list) if not x]
-                miss_vars = [var_list[i] for i in miss_inds]
-                raise ValueError(f"Could not find the following variables in the dataset: {*miss_vars,}")
-
-        return selected_vars
-
-    @staticmethod
-    def _process_one_netcdf(fname, data_norm, engine: str = "netcdf4", var_list: List = None, **kwargs):
-        with xr.open_dataset(fname, decode_cf=False, engine=engine, **kwargs) as ds_now:
-            if var_list: ds_now = ds_now[var_list]
-            ds_now = StreamMonthlyNetCDF._preprocess_ds(ds_now, data_norm)
-            ds_now = ds_now.load()
-            return ds_now
-
-    @staticmethod
-    def _preprocess_ds(ds, data_norm):
-        ds = data_norm.normalize(ds)
-        return ds.astype("float32")
-
-    def _read_mfdataset(self, files, **kwargs):
-        t0 = timer()
-        # parallel processing of files incl. normalization
-        datasets = self.pool.map(partial(self._process_one_netcdf, data_norm=self.data_norm, **kwargs), files)
-        ds_all = xr.concat(datasets, dim=self.sample_dim)
-        # clean-up
-        del datasets
-        gc.collect()
-        # timing
-        print(f"Reading dataset of {len(files)} files took {timer() - t0:.2f}s.")
-
-        return ds_all
-
-    def read_netcdf(self, set_ind):
-        set_ind = tf.keras.backend.get_value(set_ind)
-        set_ind = int(str(set_ind).lstrip("b'").rstrip("'"))
-        set_ind = int(set_ind%self.nfiles_merged)
-        il = int((self.i_loaded + 1)%2)
-        file_list_now = self.file_list_random[set_ind * self.nfiles2merge:(set_ind + 1) * self.nfiles2merge]
-        # read the normalized data into memory
-        # ds_now = xr.open_mfdataset(list(file_list_now), decode_cf=False, data_vars=self.all_vars,
-        #                           preprocess=partial(self._preprocess_ds, data_norm=self.data_norm),
-        #                           parallel=True).load()
-        self.data_loaded[il] = self._read_mfdataset(file_list_now, var_list=self.all_vars).copy()
-        nsamples = self.data_loaded[il].dims[self.sample_dim]
-        if nsamples < self.samples_merged:
-            t0 = timer()
-            add_samples = self.samples_merged - nsamples
-            add_inds = random.sample(range(nsamples), add_samples)
-            ds_add = self.data_loaded[il].isel({self.sample_dim: add_inds})
-            ds_add[self.sample_dim] = ds_add[self.sample_dim] + 1.
-            self.data_loaded[il] = xr.concat([self.data_loaded[il], ds_add], dim=self.sample_dim)
-            print(f"Appending data with {add_samples:d} samples took {timer() - t0:.2f}s.")
-
-        print(f"Currently loaded dataset has {self.data_loaded[il].dims[self.sample_dim]} samples.")
-
-        return il
-
-    def choose_data(self, il):
-        self.data_now = self.data_loaded[il]
-        self.i_loaded = il
-        return True
     
 if __name__ == "__main__":
     main()
->>>>>>> develop
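
For orientation, a hypothetical eager-mode walk-through of the StreamMonthlyNetCDF handler that the removed tf.data chain drives; the data directory, file pattern, nfiles_merge value and batch size below are placeholders, and the class definition with its imports is assumed to be in scope.

# <my_data_dir> must hold monthly netCDF files whose names contain a numeric index,
# and nfiles_merge must divide the total number of files (both enforced by the class).
ds_obj = StreamMonthlyNetCDF("<my_data_dir>", "downscaling_*.nc", nfiles_merge=4,
                             norm_dims=["time", "rlat", "rlon"])

il = ds_obj.read_netcdf(0)                 # load the first group of merged files into buffer il
ds_obj.choose_data(il)                     # make that buffer the active dataset
batch = ds_obj.getitems(list(range(32)))   # draw one batch of 32 samples (assumes >= 32 samples loaded)
print(batch.shape)                         # e.g. (32, rlat, rlon, n_predictors + n_predictands)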
diff --git a/downscaling_ap5/utils/other_utils.py b/downscaling_ap5/utils/other_utils.py
index 5cd200181e56a00410a6e5a5b200089e5ea408b3..2853df1e88789774b1821e482a81df47f9443ab5 100644
--- a/downscaling_ap5/utils/other_utils.py
+++ b/downscaling_ap5/utils/other_utils.py
@@ -337,22 +337,30 @@ def get_max_memory_usage():
     return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1000
 
 
-def copy_filelist(file_list: List, dest: str, labort: bool = True):
+def copy_filelist(file_list: List, dest_dir: str, file_list_dest: List = None, labort: bool = True):
     """
     Copy a list of files to another directory
     :param file_list: list of files to copy
-    :param dest: target directory to which files will be copied
+    :param dest_dir: target directory to which files will be copied
+    :param file_list_dest: optional list of file names at the destination (relative to dest_dir); must have the same length as file_list
     :param labort: flag to trigger raising of an error (if False, only Warning-messages will be printed)
     """
     file_list = to_list(file_list)
-    if not os.path.isdir(dest) and labort:
+    if not os.path.isdir(dest_dir) and labort:
-        raise NotADirectoryError(f"Cannot copy to non-existing directory '{dest}'.")
+        raise NotADirectoryError(f"Cannot copy to non-existing directory '{dest_dir}'.")
-    elif not os.path.isdir(dest) and not labort:
+    elif not os.path.isdir(dest_dir) and not labort:
         print(f"WARNING: Target directory for copying '{dest}' does not exist. Skip copy process...")
+        return
 
-    for f in file_list:
+    if file_list_dest is None:
+        dest = [dest_dir] * len(file_list)    # replicate the target directory so dest can be indexed uniformly below
+    else:
+        assert len(file_list) == len(file_list_dest), f"Length of filelist to copy ({len(file_list)})" + \
+                                                      f" and of filelist at destination ({len(file_list_dest)}) differ."
+        dest = [os.path.join(dest_dir, f_dest) for f_dest in file_list_dest]
+
+    for i, f in enumerate(file_list):
         if os.path.isfile(f):
-            shutil.copy(f, dest)
+            shutil.copy(f, dest[i])
         else:
             if labort:
                 raise FileNotFoundError(f"Could not find file '{f}'. Error will be raised.")
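
A brief usage sketch of the extended copy_filelist; the import line, directories and relative paths below are illustrative assumptions.

from other_utils import copy_filelist   # assumes the utils directory is on PYTHONPATH

# copy two config files under new names into an (existing) run directory
copy_filelist(["config/config_unet.json", "config/config_ds_tier2.json"],
              dest_dir="trained_models/run1",
              file_list_dest=["config_unet_run1.json", "config_ds_tier2_run1.json"])

# without file_list_dest, the original file names are kept; labort=False only warns if the target is missing
copy_filelist(["config/config_unet.json"], dest_dir="trained_models/run1", labort=False)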