diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index 1862d6734430d42a2d0cda0b199acef97b58bebb..fa55d9d944eb03d6096eea7507045a1904360a1d 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -1,7 +1,6 @@
 __author__ = "Lukas Leufen"
 __date__ = '2020-06-25'
 
-from mlair.helpers.statistics import TransformationClass
 
 DEFAULT_STATIONS = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
 DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
@@ -14,7 +13,6 @@ DEFAULT_START = "1997-01-01"
 DEFAULT_END = "2017-12-31"
 DEFAULT_WINDOW_HISTORY_SIZE = 13
 DEFAULT_OVERWRITE_LOCAL_DATA = False
-DEFAULT_TRANSFORMATION = TransformationClass(inputs_method="standardise", targets_method="standardise")
 DEFAULT_HPC_LOGIN_LIST = ["ju", "hdfmll"]  # ju(wels), hdfmll(ogin)
 DEFAULT_HPC_HOST_LIST = ["jw", "hdfmlc"]  # first part of node names for Juwels (jw(comp)) and HDF-ML (hdfmlc(ompute))
 DEFAULT_CREATE_NEW_MODEL = True
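
Note on the new transformation format: with TransformationClass and DEFAULT_TRANSFORMATION gone, transformation settings become plain dicts keyed by variable name. A minimal sketch of such a dict, inferred from the keys transform_new reads from its opts argument (the variable names are illustrative):

    # Hypothetical per-variable transformation settings replacing TransformationClass.
    # Keys mirror what transform_new reads from opts: "method", "mean", "std".
    transformation = {
        "o3": {"method": "standardise"},   # mean/std omitted -> computed from the data
        "temp": {"method": "centre"},      # centring only uses the mean
    }
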
diff --git a/mlair/data_handler/data_handler_mixed_sampling.py b/mlair/data_handler/data_handler_mixed_sampling.py
index 19fc26fe78f4aaec034d6593e3b4628b85fc5644..0a7adadb7d3c00ebd10dbff176feb4a8ff6b8d5f 100644
--- a/mlair/data_handler/data_handler_mixed_sampling.py
+++ b/mlair/data_handler/data_handler_mixed_sampling.py
@@ -192,7 +192,7 @@ class DataHandlerSeparationOfScalesSingleStation(DataHandlerMixedSamplingWithFil
         :param dim_name_of_shift: Dimension along which the shift is applied
         """
         window = -abs(window)
-        data = self.input_data.data
+        data = self.input_data
         self.history = self.stride(data, dim_name_of_shift, window)
 
     def stride(self, data: xr.DataArray, dim: str, window: int) -> xr.DataArray:
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index 654f489fab8ee6ed8eb360be54be7c755da061e1..832a643f2af7c6c2f0510fa1c2cf0353c516f67f 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -8,7 +8,7 @@ import datetime as dt
 import logging
 import os
 from functools import reduce
-from typing import Union, List, Iterable, Tuple, Dict
+from typing import Union, List, Iterable, Tuple, Dict, Optional
 
 import numpy as np
 import pandas as pd
@@ -60,7 +60,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
         self.statistics_per_var = statistics_per_var
         self.data_origin = data_origin
         self.do_transformation = transformation is not None
-        self.input_data, self.target_data = self.setup_transformation(transformation)
+        self.input_data, self.target_data = None, None
+        self._transformation = self.setup_transformation(transformation)
 
         self.station_type = station_type
         self.network = network
@@ -117,14 +118,16 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
         :return: history with dimensions datetime, window, Stations, variables.
         """
-        return self.history.transpose("datetime", "window", "Stations", "variables").copy()
+        return self.history.transpose("datetime", "window", "Stations",
+                                      "variables").copy()  # ToDo: remove hardcoded dims
 
     def get_transposed_label(self) -> xr.DataArray:
         """Return label.
 
         :return: label with dimensions datetime and window (Stations and variables are squeezed out).
         """
-        return self.label.squeeze("Stations").transpose("datetime", "window").copy()
+        return self.label.squeeze(["Stations", "variables"]).transpose("datetime",
+                                                                       "window").copy()  # ToDo: remove hardcoded dims
 
     def get_X(self, **kwargs):
         return self.get_transposed_history()
@@ -137,10 +140,81 @@ class DataHandlerSingleStation(AbstractDataHandler):
         return coords.rename(index={"station_lon": "lon", "station_lat": "lat"}).to_dict()[str(self)]
 
     def call_transform(self, inverse=False):
-        kwargs = helpers.remove_items(self.input_data.as_dict(), ["data"])
-        self.transform(self.input_data, dim=self.time_dim, inverse=inverse, **kwargs)
-        kwargs = helpers.remove_items(self.target_data.as_dict(), ["data"])
-        self.transform(self.target_data, dim=self.time_dim, inverse=inverse, **kwargs)
+        opts_input = self._transformation[0]
+        self.input_data, opts_input = self.transform_new(self.input_data, dim=self.time_dim, inverse=inverse,
+                                                         opts=opts_input)
+        opts_target = self._transformation[1]
+        self.target_data, opts_target = self.transform_new(self.target_data, dim=self.time_dim, inverse=inverse,
+                                                           opts=opts_target)
+        self._transformation = (opts_input, opts_target)
+
+    def transform_new(self, data_in, dim: Union[str, int] = 0,
+                      inverse: bool = False, opts=None):
+        """
+        Transform data according to given transformation settings.
+
+        This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0
+        and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale
+        (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This
+        method will raise an AssertionError if an internal transform method was already set ('inverse=False') or if the
+        internal transform method, internal mean and internal standard deviation weren't set ('inverse=True').
+
+        :param string/int dim: This param is not used for inverse transformation.
+                | for xarray.DataArray as string: name of dimension which should be standardised
+                | for pandas.DataFrame as int: axis of dimension which should be standardised
+        :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented
+                    yet. This param is not used for inverse transformation.
+        :param inverse: Switch between transformation and inverse transformation.
+        :param mean: Used for transformation (if required by 'method') based on external data. If 'None' the mean is
+                    calculated over the data in this class instance.
+        :param std: Used for transformation (if required by 'method') based on external data. If 'None' the std is
+                    calculated over the data in this class instance.
+        :param min: Used for transformation (if required by 'method') based on external data. If 'None' min_val is
+                    extracted from the data in this class instance.
+        :param max: Used for transformation (if required by 'method') based on external data. If 'None' max_val is
+                    extracted from the data in this class instance.
+
+        :return: xarray.DataArrays or pandas.DataFrames:
+                #. mean: Mean of data
+                #. std: Standard deviation of data
+                #. data: Standardised data
+        """
+
+        def f(data, method, *args):  # mean/std args are ignored; statistics are computed from the data
+            if method == 'standardise':
+                return statistics.standardise(data, dim)
+            elif method == 'centre':
+                return statistics.centre(data, dim)
+            elif method == 'normalise':
+                # use min/max of data or given min/max
+                raise NotImplementedError
+            else:
+                raise NotImplementedError
+
+        def f_apply(data, method, mean, std):
+            if method == "standardise":
+                return mean, std, statistics.standardise_apply(data, mean, std)
+            elif method == "centre":
+                return mean, None, statistics.centre_apply(data, mean)
+            else:
+                raise NotImplementedError
+
+        opts = opts or {}
+        opts_updated = {}
+        if not inverse:
+            transformed_values = []
+            for var in data_in.variables.values:
+                data_var = data_in.sel(variables=[var])  # ToDo: replace hardcoded variables dim
+                var_opts = opts.get(var, {})
+                _method = var_opts.get("method", "standardise")
+                _mean = var_opts.get("mean", None)
+                _std = var_opts.get("std", None)
+                _f = f if _mean is None else f_apply  # compute statistics from data vs. apply external statistics
+                mean, std, values = _f(data_var, _method, _mean, _std)
+                opts_updated[var] = {"method": _method, "mean": mean, "std": std}
+                transformed_values.append(values)
+            return xr.concat(transformed_values, dim="variables"), opts_updated  # ToDo: replace hardcoded variables dim
+        else:
+            self.inverse_transform(data_in)  # ToDo: add return statement
 
     @TimeTrackingWrapper
     def setup_samples(self):
@@ -159,9 +233,10 @@ class DataHandlerSingleStation(AbstractDataHandler):
 
     def set_inputs_and_targets(self):
         inputs = self._data.sel({self.target_dim: helpers.to_list(self.variables)})
-        targets = self._data.sel({self.target_dim: self.target_var})
-        self.input_data.data = inputs
-        self.target_data.data = targets
+        targets = self._data.sel(
+            {self.target_dim: helpers.to_list(self.target_var)})  # ToDo: is it right to expand this dim??
+        self.input_data = inputs
+        self.target_data = targets
 
     def make_samples(self):
         self.make_history_window(self.target_dim, self.window_history_size, self.time_dim)
@@ -395,7 +470,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         :param dim_name_of_shift: Dimension along which the shift is applied
         """
         window = -abs(window)
-        data = self.input_data.data
+        data = self.input_data
         self.history = self.shift(data, dim_name_of_shift, window, offset=self.window_history_offset)
 
     def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str,
@@ -412,7 +487,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         :param window: lead time of label
         """
         window = abs(window)
-        data = self.target_data.data
+        data = self.target_data
         self.label = self.shift(data, dim_name_of_shift, window)
 
     def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None:
@@ -425,7 +500,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         :param target_var: Name of observation variable(s) in 'dimension'
         :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied
         """
-        data = self.target_data.data
+        data = self.target_data
         self.observation = self.shift(data, dim_name_of_shift, 0)
 
     def remove_nan(self, dim: str) -> None:
@@ -481,7 +556,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         return data.loc[{coord: slice(str(start), str(end))}]
 
     @staticmethod
-    def setup_transformation(transformation: statistics.TransformationClass):
+    def setup_transformation(transformation: Union[None, dict, Tuple]) -> Tuple[Optional[dict], Optional[dict]]:
         """
         Set up transformation by extracting all relevant information.
 
@@ -491,17 +566,17 @@ class DataHandlerSingleStation(AbstractDataHandler):
           design behaviour)
         """
         if transformation is None:
-            return statistics.DataClass(), statistics.DataClass()
-        elif isinstance(transformation, statistics.DataClass):
-            return transformation, transformation
-        elif isinstance(transformation, statistics.TransformationClass):
-            return copy.deepcopy(transformation.inputs), copy.deepcopy(transformation.targets)
+            return None, None
+        elif isinstance(transformation, dict):
+            return copy.deepcopy(transformation), copy.deepcopy(transformation)
+        elif isinstance(transformation, tuple) and len(transformation) == 2:
+            return copy.deepcopy(transformation)
         else:
             raise NotImplementedError("Cannot handle this.")
 
     def transform(self, data_class, dim: Union[str, int] = 0, transform_method: str = 'standardise',
                   inverse: bool = False, mean=None,
-                  std=None, min=None, max=None) -> None:
+                  std=None, min=None, max=None, opts=None) -> None:
         """
         Transform data according to given transformation settings.
 
@@ -614,7 +689,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
         # update X and Y
         self.make_samples()
 
-    def get_transformation_targets(self) -> Tuple[data_or_none, data_or_none, str]:
+    def get_transformation_targets(self) -> Dict:
         """
         Extract transformation statistics and method.
 
@@ -622,9 +697,9 @@ class DataHandlerSingleStation(AbstractDataHandler):
         depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are
         returned with None as fill value.
 
-        :return: mean, standard deviation and transformation method
+        :return: dict with all transformation information
         """
-        return self.target_data.mean, self.target_data.std, self.target_data.transform_method
+        return copy.deepcopy(self._transformation[1])
 
 
 if __name__ == "__main__":
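
setup_transformation now accepts None, a single dict (used for both inputs and targets), or a 2-tuple of (input, target) dicts, returning deep copies in the non-None cases. A short sketch of the three call patterns (the settings themselves are illustrative):

    # None disables transformation entirely:
    DataHandlerSingleStation.setup_transformation(None)  # -> (None, None)
    # a single dict is duplicated for inputs and targets:
    DataHandlerSingleStation.setup_transformation({"o3": {"method": "standardise"}})
    # a 2-tuple keeps separate input and target settings:
    DataHandlerSingleStation.setup_transformation(
        ({"temp": {"method": "centre"}}, {"o3": {"method": "standardise"}}))
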
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 291bbc6616314db61282c380a6b3e105d8b6248a..070da625fdcfb4a8ccdc2a449096a8f9f0a2e78f 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -243,45 +243,74 @@ class DefaultDataHandler(AbstractDataHandler):
         """
 
         sp_keys = {k: copy.deepcopy(kwargs[k]) for k in cls._requirements if k in kwargs}
-        transformation_class = sp_keys.get("transformation", None)
-        if transformation_class is None:
+        transformation_dict = sp_keys.get("transformation", None)
+        if transformation_dict is None:
             return
-
-        transformation_inputs = transformation_class.inputs
-        if transformation_inputs.mean is not None:
-            return
-        means = [None, None]
-        stds = [None, None]
-
-        if multiprocessing.cpu_count() > 1:  # parallel solution
-            logging.info("use parallel transformation approach")
-            pool = multiprocessing.Pool()
-            logging.info(f"running {getattr(pool, '_processes')} processes in parallel")
-            output = [
-                pool.apply_async(f_proc, args=(cls.data_handler_transformation, station), kwds=sp_keys)
-                for station in set_stations]
-            for p in output:
-                dh, s = p.get()
-                if dh is not None:
-                    for i, data in enumerate([dh.input_data, dh.target_data]):
-                        means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
-                        stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
-        else:  # serial solution
-            logging.info("use serial transformation approach")
-            for station in set_stations:
-                dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
-                if dh is not None:
-                    for i, data in enumerate([dh.input_data, dh.target_data]):
-                        means[i] = data.mean.copy(deep=True) if means[i] is None else means[i].combine_first(data.mean)
-                        stds[i] = data.std.copy(deep=True) if stds[i] is None else stds[i].combine_first(data.std)
-
-        if means[0] is None:
-            return None
-        transformation_class.inputs.mean = means[0].mean("Stations")
-        transformation_class.inputs.std = stds[0].mean("Stations")
-        transformation_class.targets.mean = means[1].mean("Stations")
-        transformation_class.targets.std = stds[1].mean("Stations")
-        return transformation_class
+        if isinstance(transformation_dict, dict):  # expand a single dict to a tuple of (input, target) settings
+            transformation_dict = copy.deepcopy(transformation_dict), copy.deepcopy(transformation_dict)
+        for station in set_stations:
+            dh, s = f_proc(cls.data_handler_transformation, station, **sp_keys)
+            if dh is not None:
+                for i, transformation in enumerate(dh._transformation):
+                    for var in transformation.keys():
+                        if var not in transformation_dict[i].keys():
+                            transformation_dict[i][var] = {}
+                        opts = transformation[var]
+                        assert transformation_dict[i][var].get("method", opts["method"]) == opts["method"]
+                        transformation_dict[i][var]["method"] = opts["method"]
+                        for k in ["mean", "std"]:
+                            old = transformation_dict[i][var].get(k, None)
+                            new = opts.get(k)
+                            transformation_dict[i][var][k] = new if old is None else old.combine_first(new)
+        pop_list = []
+        for i, transformation in enumerate(transformation_dict):
+            for k in transformation.keys():
+                try:
+                    if transformation[k]["mean"] is not None:
+                        transformation_dict[i][k]["mean"] = transformation[k]["mean"].mean("Stations")
+                    if transformation[k]["std"] is not None:
+                        transformation_dict[i][k]["std"] = transformation[k]["std"].mean("Stations")
+                except KeyError:
+                    pop_list.append((i, k))
+        for (i, k) in pop_list:
+            transformation_dict[i].pop(k)
+        return transformation_dict
 
     def get_coordinates(self):
         return self.id_class.get_coordinates()
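
The pooling of per-station statistics relies on xarray's combine_first followed by a mean over the "Stations" dimension: the first station's statistics form the base, later stations only fill coordinates that are still missing, and the final mean yields one pooled value per variable. A standalone sketch of that reduction (station IDs and values are made up):

    import xarray as xr

    # per-station means for one variable; combine_first keeps existing entries
    # and only fills coordinates that are still missing
    m1 = xr.DataArray([1.0], dims="Stations", coords={"Stations": ["DEBW107"]})
    m2 = xr.DataArray([3.0], dims="Stations", coords={"Stations": ["DEBY081"]})
    merged = m1.combine_first(m2)     # contains both stations
    pooled = merged.mean("Stations")  # single pooled mean, stored in the opts dict
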
diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index 546a463650ccca4c6f7e2b63b3afb01db9d90a40..bfc1490d9826be008847502a6181c492060acda2 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -17,37 +17,14 @@ from mlair.helpers import to_list, remove_items
 Data = Union[xr.DataArray, pd.DataFrame]
 
 
-class DataClass:
-
-    def __init__(self, data=None, mean=None, std=None, max=None, min=None, transform_method=None):
-        self.data = data
-        self.mean = mean
-        self.std = std
-        self.max = max
-        self.min = min
-        self.transform_method = transform_method
-        self._method = None
-
-    def as_dict(self):
-        return remove_items(self.__dict__, "_method")
-
-
-class TransformationClass:
-
-    def __init__(self, inputs_mean=None, inputs_std=None, inputs_method=None, targets_mean=None, targets_std=None,
-                 targets_method=None):
-        self.inputs = DataClass(mean=inputs_mean, std=inputs_std, transform_method=inputs_method)
-        self.targets = DataClass(mean=targets_mean, std=targets_std, transform_method=targets_method)
-
-
-def apply_inverse_transformation(data: Data, mean: Data, std: Data = None, method: str = "standardise") -> Data:
+def apply_inverse_transformation(data: Data, method: str = "standardise", mean: Data = None, std: Data = None) -> Data:
     """
     Apply inverse transformation for given statistics.
 
     :param data: transform this data back
+    :param method: transformation method
     :param mean: mean of transformation
     :param std: standard deviation of transformation (optional)
-    :param method: transformation method
 
     :return: inverse transformed data
     """
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 52b1121e7b8f165476d3c27d9e24b077a731f8e5..b1426ed289783bcd2fe7939d280aa20d6e703860 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -147,11 +147,12 @@ class PlotMonthlySummary(AbstractPlotClass):
     """
 
     def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None,
-                 plot_folder: str = ".", target_var_unit: str = 'ppb'):
+                 plot_folder: str = ".", target_var_unit: str = 'ppb', model_name="nn"):
         """Set attributes and create plot."""
         super().__init__(plot_folder, "monthly_summary_box_plot")
         self._data_path = data_path
         self._data_name = name
+        self._model_name = model_name
         self._data = self._prepare_data(stations)
         self._window_lead_time = self._get_window_lead_time(window_lead_time)
         self._plot(target_var, target_var_unit)
@@ -173,14 +174,14 @@ class PlotMonthlySummary(AbstractPlotClass):
             file_name = os.path.join(self._data_path, self._data_name % station)
             data = xr.open_dataarray(file_name)
 
-            data_cnn = data.sel(type="CNN").squeeze()
-            if len(data_cnn.shape) > 1:
-                data_cnn = data_cnn.assign_coords(ahead=[f"{days}d" for days in data_cnn.coords["ahead"].values])
+            data_nn = data.sel(type=self._model_name).squeeze()
+            if len(data_nn.shape) > 1:
+                data_nn = data_nn.assign_coords(ahead=[f"{days}d" for days in data_nn.coords["ahead"].values])
 
             data_obs = data.sel(type="obs", ahead=1).squeeze()
             data_obs.coords["ahead"] = "obs"
 
-            data_concat = xr.concat([data_obs, data_cnn], dim="ahead")
+            data_concat = xr.concat([data_obs, data_nn], dim="ahead")
             data_concat = data_concat.drop_vars("type")
 
             new_index = data_concat.index.values.astype("datetime64[M]").astype(int) % 12 + 1
@@ -347,7 +348,7 @@ class PlotStationMap(AbstractPlotClass):
             return arr[1] - arr[0], arr[3] - arr[2]
 
         def find_ratio(delta, reference=5):
-            return max(abs(reference / delta[0]), abs(reference / delta[1]))
+            return min(max(abs(reference / delta[0]), abs(reference / delta[1])), 5)
 
         extent = self._ax.get_extent(crs=ccrs.PlateCarree())
         ratio = find_ratio(diff(extent))
@@ -376,7 +377,7 @@ class PlotConditionalQuantiles(AbstractPlotClass):
     :param plot_folder: path where the plots are stored
     :param plot_per_seasons: if `True` create conditional quantile plots for the _seasons (DJF, MAM, JJA, SON) individually
     :param rolling_window: smoothing of quantiles (3 is used by Murphy et al.)
-    :param model_mame: name of the model prediction as stored in netCDF file (for example "CNN")
+    :param model_name: name of the model prediction as stored in netCDF file (for example "nn")
     :param obs_name: name of observation as stored in netCDF file (for example "obs")
     :param kwargs: Some further arguments which are listed in self._opts
     """
@@ -389,13 +390,13 @@ class PlotConditionalQuantiles(AbstractPlotClass):
     warnings.filterwarnings("ignore", message="Attempted to set non-positive bottom ylim on a log-scaled axis.")
 
     def __init__(self, stations: List, data_pred_path: str, plot_folder: str = ".", plot_per_seasons=True,
-                 rolling_window: int = 3, model_mame: str = "CNN", obs_name: str = "obs", **kwargs):
+                 rolling_window: int = 3, model_name: str = "nn", obs_name: str = "obs", **kwargs):
         """Initialise."""
         super().__init__(plot_folder, "conditional_quantiles")
         self._data_pred_path = data_pred_path
         self._stations = stations
         self._rolling_window = rolling_window
-        self._model_name = model_mame
+        self._model_name = model_name
         self._obs_name = obs_name
         self._opts = self._get_opts(kwargs)
         self._seasons = ['DJF', 'MAM', 'JJA', 'SON'] if plot_per_seasons is True else ""
@@ -619,7 +620,7 @@ class PlotClimatologicalSkillScore(AbstractPlotClass):
     :param plot_folder: path to save the plot (default: current directory)
     :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True)
     :param extra_name_tag: additional tag that can be included in the plot name (default "")
-    :param model_setup: architecture type to specify plot name (default "CNN")
+    :param model_setup: architecture type to specify plot name (default "")
 
     """
 
@@ -998,11 +999,13 @@ class PlotTimeSeries:
     """
 
     def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = ".",
-                 sampling="daily"):
+                 sampling="daily", model_name="nn", obs_name="obs"):
         """Initialise."""
         self._data_path = data_path
         self._data_name = name
         self._stations = stations
+        self._model_name = model_name
+        self._obs_name = obs_name
         self._window_lead_time = self._get_window_lead_time(window_lead_time)
         self._sampling = self._get_sampling(sampling)
         self._plot(plot_folder)
@@ -1034,7 +1037,7 @@ class PlotTimeSeries:
         logging.debug(f"... preprocess station {station}")
         file_name = os.path.join(self._data_path, self._data_name % station)
         data = xr.open_dataarray(file_name)
-        return data.sel(type=["CNN", "obs"])
+        return data.sel(type=[self._model_name, self._obs_name])
 
     def _plot(self, plot_folder):
         pdf_pages = self._create_pdf_pages(plot_folder)
@@ -1088,7 +1091,8 @@ class PlotTimeSeries:
     def _plot_ahead(self, ax, data):
         color = sns.color_palette("Blues_d", self._window_lead_time).as_hex()
         for ahead in data.coords["ahead"].values:
-            plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze().shift(index=ahead)
+            plot_data = data.sel(type=self._model_name, ahead=ahead).drop(["type", "ahead"]).squeeze().shift(
+                index=ahead)
             label = f"{ahead}{self._sampling}"
             ax.plot(plot_data, color=color[ahead - 1], label=label)
 
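The plotting classes no longer hardcode the "CNN" type label; the prediction name is passed in and must match the "type" coordinate of the stored forecast files. An illustrative instantiation with the new keyword (the paths and file-name pattern are placeholders):

    stations = ["DEBW107"]  # placeholder station list
    # model_name must match the "type" coordinate in the forecast netCDF
    # files ("nn" is the new default instead of "CNN")
    PlotMonthlySummary(stations, data_path="forecasts", name="forecasts_%s_test.nc",
                       target_var="o3", model_name="nn")
    PlotTimeSeries(stations, data_path="forecasts", name="forecasts_%s_test.nc",
                   model_name="nn", obs_name="obs")
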
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index 34da6cb33828d8fd2b34d15dd50da3e30c8af17e..ee8506ee6c445ccf9ac93dc0498804841538311d 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -9,7 +9,7 @@ from typing import Union, Dict, Any, List, Callable
 from mlair.configuration import path_config
 from mlair import helpers
 from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT, DEFAULT_NETWORK, DEFAULT_STATION_TYPE, \
-    DEFAULT_START, DEFAULT_END, DEFAULT_WINDOW_HISTORY_SIZE, DEFAULT_OVERWRITE_LOCAL_DATA, DEFAULT_TRANSFORMATION, \
+    DEFAULT_START, DEFAULT_END, DEFAULT_WINDOW_HISTORY_SIZE, DEFAULT_OVERWRITE_LOCAL_DATA, \
     DEFAULT_HPC_LOGIN_LIST, DEFAULT_HPC_HOST_LIST, DEFAULT_CREATE_NEW_MODEL, DEFAULT_TRAIN_MODEL, \
     DEFAULT_FRACTION_OF_TRAINING, DEFAULT_EXTREME_VALUES, DEFAULT_EXTREMES_ON_RIGHT_TAIL_ONLY, DEFAULT_PERMUTE_DATA, \
     DEFAULT_BATCH_SIZE, DEFAULT_EPOCHS, DEFAULT_TARGET_VAR, DEFAULT_TARGET_DIM, DEFAULT_WINDOW_LEAD_TIME, \
@@ -294,7 +294,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("window_history_size", window_history_size, default=DEFAULT_WINDOW_HISTORY_SIZE)
         self._set_param("overwrite_local_data", overwrite_local_data, default=DEFAULT_OVERWRITE_LOCAL_DATA,
                         scope="preprocessing")
-        self._set_param("transformation", transformation, default=DEFAULT_TRANSFORMATION)
+        self._set_param("transformation", transformation, default=None)
         self._set_param("transformation", None, scope="preprocessing")
         self._set_param("data_handler", data_handler, default=DefaultDataHandler)
 
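ExperimentSetup now defaults transformation to None and explicitly clears it in the preprocessing scope, where the pooled statistics are later computed from the training stations. An illustrative call passing the new dict form (the contents are an assumption, not a tested configuration):

    ExperimentSetup(stations=["DEBW107", "DEBY081"],
                    transformation={"o3": {"method": "standardise"}})
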
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 020f0a42b9e205800b39ccbfdec20cbba8364f1f..39f5f450750b5af9d00a78d632caa66df9dbe0c4 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -406,7 +406,7 @@ class PostProcessing(RunEnvironment):
             observation_data = data.get_observation()
 
             # get scaling parameters
-            mean, std, transformation_method = data.get_transformation_Y()
+            transformation_opts = data.get_transformation_Y()
 
             for normalised in [True, False]:
                 # create empty arrays
@@ -414,20 +414,17 @@ class PostProcessing(RunEnvironment):
                     target_data, count=4)
 
                 # nn forecast
-                nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method,
-                                                         normalised)
+                nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_opts, normalised)
 
                 # persistence
-                persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, mean, std,
-                                                                           transformation_method, normalised)
+                persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
+                                                                           transformation_opts, normalised)
 
                 # ols
-                ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method,
-                                                           normalised)
+                ols_prediction = self._create_ols_forecast(input_data, ols_prediction, transformation_opts, normalised)
 
                 # observation
-                observation = self._create_observation(target_data, observation, mean, std, transformation_method,
-                                                       normalised)
+                observation = self._create_observation(target_data, observation, transformation_opts, normalised)
 
                 # merge all predictions
                 full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
@@ -468,10 +465,7 @@ class PostProcessing(RunEnvironment):
         forecast.coords["type"] = [competitor_name]
         return forecast
 
-
-    @staticmethod
-    def _create_observation(data, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
-                            normalised: bool) -> xr.DataArray:
+    def _create_observation(self, data, _, transformation_opts: dict, normalised: bool) -> xr.DataArray:
         """
         Create observation as ground truth from given data.
 
@@ -486,11 +480,11 @@ class PostProcessing(RunEnvironment):
         :return: filled data array with observation
         """
         if not normalised:
-            data = statistics.apply_inverse_transformation(data, mean, std, transformation_method)
+            data = self._inverse_transformation(data, transformation_opts)
         return data
 
-    def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray, mean: xr.DataArray,
-                             std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
+    def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray, transformation_opts: dict,
+                             normalised: bool) -> xr.DataArray:
         """
         Create ordinary least square model forecast with given input data.
 
@@ -509,11 +503,11 @@ class PostProcessing(RunEnvironment):
         target_shape = ols_prediction.values.shape
         ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
         if not normalised:
-            ols_prediction = statistics.apply_inverse_transformation(ols_prediction, mean, std, transformation_method)
+            ols_prediction = self._inverse_transformation(ols_prediction, transformation_opts)
         return ols_prediction
 
-    def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, mean: xr.DataArray,
-                                     std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
+    def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, transformation_opts: dict,
+                                     normalised: bool) -> xr.DataArray:
         """
         Create persistence forecast with given data.
 
@@ -532,12 +526,11 @@ class PostProcessing(RunEnvironment):
         tmp_persi = data.copy()
         persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T
         if not normalised:
-            persistence_prediction = statistics.apply_inverse_transformation(persistence_prediction, mean, std,
-                                                                             transformation_method)
+            persistence_prediction = self._inverse_transformation(persistence_prediction, transformation_opts)
         return persistence_prediction
 
-    def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, mean: xr.DataArray,
-                            std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
+    def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_opts: dict,
+                            normalised: bool) -> xr.DataArray:
         """
         Create NN forecast for given input data.
 
@@ -564,9 +557,29 @@ class PostProcessing(RunEnvironment):
         else:
             raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
         if not normalised:
-            nn_prediction = statistics.apply_inverse_transformation(nn_prediction, mean, std, transformation_method)
+            nn_prediction = self._inverse_transformation(nn_prediction, transformation_opts)
         return nn_prediction
 
+    def _inverse_transformation(self, data, transformation_opts):
+        """Invert the transformation for each variable separately, using its stored method, mean and std."""
+        transformed_values = []
+        for var in to_list(data.variables.values.tolist()):
+            if "variables" in data.dims:
+                data_var = data.sel(variables=[var])  # ToDo: replace hardcoded variables dim
+            else:
+                data_var = data
+            var_opts = transformation_opts.get(var, {})
+            _method = var_opts.get("method", "standardise")
+            _mean = var_opts.get("mean", None)
+            _std = var_opts.get("std", None)
+            values = statistics.apply_inverse_transformation(data_var, _method, _mean, _std)
+            transformed_values.append(values)
+        res = xr.concat(transformed_values, dim="variables")  # ToDo: replace hardcoded variables dim
+        if res.shape == data.shape:
+            return res
+        else:
+            return res.squeeze("variables")  # ToDo: replace hardcoded variables dim
+
     @staticmethod
     def _create_empty_prediction_arrays(target_data, count=1):
         """
@@ -630,22 +643,19 @@ class PostProcessing(RunEnvironment):
         """
         try:
             data = self.train_val_data[station]
-            # target_data = data.get_Y(as_numpy=False)
             observation = data.get_observation()
-            mean, std, transformation_method = data.get_transformation_Y()
-            # external_data = self._create_observation(target_data, None, mean, std, transformation_method, normalised=False)
-            # external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"])
-            external_data = self._create_observation(observation, None, mean, std, transformation_method, normalised=False)
+            transformation_opts = data.get_transformation_Y()
+            external_data = self._create_observation(observation, None, transformation_opts, normalised=False)
             return external_data.rename({external_data.dims[0]: 'index'})
         except (IndexError, KeyError):
             return None
 
     def calculate_skill_scores(self) -> Tuple[Dict, Dict]:
         """
-        Calculate skill scores of CNN forecast.
+        Calculate skill scores of NN forecast.
 
-        The competitive skill score compares the CNN prediction with persistence and ordinary least squares forecasts.
-        Whereas, the climatological skill scores evaluates the CNN prediction in terms of meaningfulness in comparison
+        The competitive skill score compares the NN prediction with persistence and ordinary least squares forecasts,
+        whereas the climatological skill score evaluates the NN prediction in terms of meaningfulness in comparison
         to different climatological references.
 
         :return: competitive and climatological skill scores
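
_inverse_transformation walks the hardcoded "variables" dimension and inverts each variable with its own stored method and statistics, so different variables may use different transformation methods. A compact, simplified re-statement of that per-variable dispatch (assuming the data carries a "variables" coordinate):

    import xarray as xr
    from mlair.helpers import statistics

    def invert_per_variable(data: xr.DataArray, opts: dict) -> xr.DataArray:
        # simplified sketch of PostProcessing._inverse_transformation
        parts = []
        for var in data.coords["variables"].values.tolist():
            var_opts = opts.get(var, {})
            parts.append(statistics.apply_inverse_transformation(
                data.sel(variables=[var]),
                var_opts.get("method", "standardise"),
                var_opts.get("mean"), var_opts.get("std")))
        return xr.concat(parts, dim="variables")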