diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py
index 26ccf69c85e999c540e656a2ceac5737390a579e..f085d18bb8d33839a0e3b5f6f3d5ada92134e7f6 100644
--- a/mlair/data_handler/abstract_data_handler.py
+++ b/mlair/data_handler/abstract_data_handler.py
@@ -36,6 +36,13 @@ class AbstractDataHandler:
     def transformation(cls, *args, **kwargs):
         return None
 
+    def apply_transformation(self, data, inverse=False, **kwargs):
+        """
+        This method must return transformed data. The flag inverse can be used to trigger either transformation or its
+        inverse method.
+        """
+        raise NotImplementedError
+
     def get_X(self, upsampling=False, as_numpy=False):
         raise NotImplementedError
 
diff --git a/mlair/data_handler/data_handler_single_station.py b/mlair/data_handler/data_handler_single_station.py
index 888554c1fd04cf6efbf22e3732fafc7e70760197..b176ccd6a3d9abf3d372c931b2182eaa3da95920 100644
--- a/mlair/data_handler/data_handler_single_station.py
+++ b/mlair/data_handler/data_handler_single_station.py
@@ -213,7 +213,7 @@ class DataHandlerSingleStation(AbstractDataHandler):
                 transformed_values.append(values)
             return xr.concat(transformed_values, dim="variables"), opts_updated  # ToDo: replace hardcoded variables dim
         else:
-            self.inverse_transform(data_in, opts)  # ToDo: add return statement
+            return self.inverse_transform(data_in, opts)  # ToDo: add return statement
 
     @TimeTrackingWrapper
     def setup_samples(self):
@@ -614,35 +614,48 @@ class DataHandlerSingleStation(AbstractDataHandler):
                 raise NotImplementedError
 
         transformed_values = []
+        sel_dim = "variables"  # ToDo: replace hardcoded variables dim
+        squeeze = False
+        if sel_dim in data_in.coords:
+            if sel_dim not in data_in.dims:
+                data_in = data_in.expand_dims(sel_dim)
+                squeeze = True
+        else:
+            raise IndexError(f"Could not find given dimension: {sel_dim}. Available is: {data_in.coords}")
         for var in data_in.variables.values:
-            data_var = data_in.sel(variables=[var])  # ToDo: replace hardcoded variables dim
+            data_var = data_in.sel(**{sel_dim: [var]})
             var_opts = opts.get(var, {})
             _method = var_opts.get("method", None)
             if _method is None:
-                raise AssertionError("Inverse transformation method is not set. Data cannot be inverse transformed.")
+                raise AssertionError(f"Inverse transformation method is not set for {var}.")
             _mean = var_opts.get("mean", None)
             _std = var_opts.get("std", None)
             self.check_inverse_transform_params(_method, _mean, _std)
-
             values = f_inverse(data_var, _method, _mean, _std)
             transformed_values.append(values)
-        return xr.concat(transformed_values, dim="variables")  # ToDo: replace hardcoded variables dim
+        res = xr.concat(transformed_values, dim=sel_dim)
+        return res.squeeze(sel_dim) if squeeze else res
 
-    def get_transformation_targets(self) -> Dict:
+    def apply_transformation(self, data, base=None, dim=0, inverse=False):
         """
-        Extract transformation statistics and method.
+        Apply transformation on external data. Specify if transformation should be based on parameters related to input
+        or target data using `base`. This method can also apply inverse transformation.
 
-        Get mean and standard deviation for target values and the transformation method if set. If a transformation
-        depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are
-        returned with None as fill value.
-
-        :return: dict with all transformation information
+        :param data:
+        :param base:
+        :param dim:
+        :param inverse:
+        :return:
         """
-        return copy.deepcopy(self._transformation[1])
-
-    def apply_transformation(self, data, transformation_opts, dim=0, inverse=False):
-
-        return self.transform(data, dim=dim, opts=transformation_opts, inverse=inverse)
+        if base in ["target", 1]:
+            pos = 1
+        elif base in ["input", 0]:
+            pos = 0
+        else:
+            raise ValueError("apply transformation requires a reference for transformation options. Please specify if"
+                             "you want to use input or target transformation using the parameter 'base'. Given was: " +
+                             base)
+        return self.transform(data, dim=dim, opts=self._transformation[pos], inverse=inverse)
 
 
 if __name__ == "__main__":
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 4b7ec3282d6214179921e8e9f763c63d3b403f71..af3e64f48d2c3c40cf536d848453659a277de80a 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -145,11 +145,8 @@ class DefaultDataHandler(AbstractDataHandler):
     def get_observation(self):
         return self.id_class.observation.copy().squeeze()
 
-    def get_transformation_Y(self):
-        return self.id_class.get_transformation_targets()
-
-    def apply_transformation(self, data, transformation_opts, dim=0, inverse=False):
-        return self.id_class.transform(data, dim=dim, opts=transformation_opts, inverse=inverse)
+    def apply_transformation(self, data, base="target", dim=0, inverse=False):
+        return self.id_class.apply_transformation(data, dim=dim, base=base, inverse=inverse)
 
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                           timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 39f5f450750b5af9d00a78d632caa66df9dbe0c4..eaa593050ccf5dcaf120e80b8c190302b83f054f 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -6,7 +6,7 @@ __date__ = '2019-12-11'
 import inspect
 import logging
 import os
-from typing import Dict, Tuple, Union, List
+from typing import Dict, Tuple, Union, List, Callable
 
 import keras
 import numpy as np
@@ -394,6 +394,57 @@ class PostProcessing(RunEnvironment):
         """
         Create predictions for NN, OLS, and persistence and add true observation as reference.
 
+        Predictions are filled in an array with full index range. Therefore, predictions can have missing values. All
+        predictions for a single station are stored locally under `<forecast/forecast_norm>_<station>_test.nc` and can
+        be found inside `forecast_path`.
+        """
+        logging.debug("start make_prediction")
+        time_dimension = self.data_store.get("time_dim")
+        for i, data in enumerate(self.test_data):
+            input_data = data.get_X()
+            target_data = data.get_Y(as_numpy=False)
+            observation_data = data.get_observation()
+
+            # get scaling parameters
+            transformation_func = data.apply_transformation
+
+            for normalised in [True, False]:
+                # create empty arrays
+                nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(
+                    target_data, count=4)
+
+                # nn forecast
+                nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_func, normalised)
+
+                # persistence
+                persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
+                                                                           transformation_func, normalised)
+
+                # ols
+                ols_prediction = self._create_ols_forecast(input_data, ols_prediction, transformation_func, normalised)
+
+                # observation
+                observation = self._create_observation(target_data, observation, transformation_func, normalised)
+
+                # merge all predictions
+                full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
+                prediction_dict = {self.forecast_indicator: nn_prediction,
+                                   "persi": persistence_prediction,
+                                   "obs": observation,
+                                   "ols": ols_prediction}
+                all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes['window']),
+                                                              time_dimension, **prediction_dict)
+
+                # save all forecasts locally
+                path = self.data_store.get("forecast_path")
+                prefix = "forecasts_norm" if normalised is True else "forecasts"
+                file = os.path.join(path, f"{prefix}_{str(data)}_test.nc")
+                all_predictions.to_netcdf(file)
+
+    def make_prediction_old(self):
+        """
+        Create predictions for NN, OLS, and persistence and add true observation as reference.
+
         Predictions are filled in an array with full index range. Therefore, predictions can have missing values. All
         predictions for a single station are stored locally under `<forecast/forecast_norm>_<station>_test.nc` and can
         be found inside `forecast_path`.
@@ -465,26 +516,24 @@ class PostProcessing(RunEnvironment):
         forecast.coords["type"] = [competitor_name]
         return forecast
 
-    def _create_observation(self, data, _, transformation_opts: dict, normalised: bool) -> xr.DataArray:
+    def _create_observation(self, data, _, transformation_func: Callable, normalised: bool) -> xr.DataArray:
         """
         Create observation as ground truth from given data.
 
         Inverse transformation is applied to the ground truth to get the output in the original space.
 
         :param data: observation
-        :param mean: mean of target value transformation
-        :param std: standard deviation of target value transformation
-        :param transformation_method: target values transformation method
+        :param transformation_func: a callable function to apply inverse transformation
         :param normalised: transform ground truth in original space if false, or use normalised predictions if true
 
         :return: filled data array with observation
         """
         if not normalised:
-            data = self._inverse_transformation(data, transformation_opts)
+            data = transformation_func(data, "target", inverse=True)
         return data
 
-    def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray, transformation_opts: dict,
-                             normalised: bool) -> xr.DataArray:
+    def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray,
+                             transformation_func: Callable, normalised: bool) -> xr.DataArray:
         """
         Create ordinary least square model forecast with given input data.
 
@@ -492,9 +541,7 @@ class PostProcessing(RunEnvironment):
 
         :param input_data: transposed history from DataPrep
         :param ols_prediction: empty array in right shape to fill with data
-        :param mean: mean of target value transformation
-        :param std: standard deviation of target value transformation
-        :param transformation_method: target values transformation method
+        :param transformation_func: a callable function to apply inverse transformation
         :param normalised: transform prediction in original space if false, or use normalised predictions if true
 
         :return: filled data array with ols predictions
@@ -503,10 +550,10 @@ class PostProcessing(RunEnvironment):
         target_shape = ols_prediction.values.shape
         ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
         if not normalised:
-            ols_prediction = self._inverse_transformation(ols_prediction, transformation_opts)
+            ols_prediction = transformation_func(ols_prediction, "target", inverse=True)
         return ols_prediction
 
-    def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, transformation_opts: dict,
+    def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, transformation_func: Callable,
                                      normalised: bool) -> xr.DataArray:
         """
         Create persistence forecast with given data.
@@ -516,9 +563,7 @@ class PostProcessing(RunEnvironment):
 
         :param data: observation
         :param persistence_prediction: empty array in right shape to fill with data
-        :param mean: mean of target value transformation
-        :param std: standard deviation of target value transformation
-        :param transformation_method: target values transformation method
+        :param transformation_func: a callable function to apply inverse transformation
         :param normalised: transform prediction in original space if false, or use normalised predictions if true
 
         :return: filled data array with persistence predictions
@@ -526,10 +571,10 @@ class PostProcessing(RunEnvironment):
         tmp_persi = data.copy()
         persistence_prediction.values = np.tile(tmp_persi, (self.window_lead_time, 1)).T
         if not normalised:
-            persistence_prediction = self._inverse_transformation(persistence_prediction, transformation_opts)
+            persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True)
         return persistence_prediction
 
-    def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_opts: dict,
+    def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
                             normalised: bool) -> xr.DataArray:
         """
         Create NN forecast for given input data.
@@ -540,9 +585,7 @@ class PostProcessing(RunEnvironment):
 
         :param input_data: transposed history from DataPrep
         :param nn_prediction: empty array in right shape to fill with data
-        :param mean: mean of target value transformation
-        :param std: standard deviation of target value transformation
-        :param transformation_method: target values transformation method
+        :param transformation_func: a callable function to apply inverse transformation
         :param normalised: transform prediction in original space if false, or use normalised predictions if true
 
         :return: filled data array with nn predictions
@@ -557,29 +600,9 @@ class PostProcessing(RunEnvironment):
         else:
             raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
         if not normalised:
-            nn_prediction = self._inverse_transformation(nn_prediction, transformation_opts)
+            nn_prediction = transformation_func(nn_prediction, base="target", inverse=True)
         return nn_prediction
 
-    def _inverse_transformation(self, data, transformation_opts):
-        transformed_values = []
-        for var in to_list(data.variables.values.tolist()):
-            if "variables" in data.dims:
-                data_var = data.sel(variables=[var])  # ToDo: replace hardcoded variables dim
-            else:
-                data_var = data
-            var_opts = transformation_opts.get(var, {})
-            _method = var_opts.get("method", "standardise")
-            _mean = var_opts.get("mean", None)
-            _std = var_opts.get("std", None)
-            values = statistics.apply_inverse_transformation(data_var, _method, _mean,
-                                                             _std)  # ToDo: replace hardcoded variables dim
-            transformed_values.append(values)  # ToDo: replace hardcoded variables dim
-        res = xr.concat(transformed_values, dim="variables")  # ToDo: replace hardcoded variables dim
-        if res.shape == data.shape:
-            return res
-        else:
-            return res.squeeze("variables")  # ToDo: replace hardcoded variables dim
-
     @staticmethod
     def _create_empty_prediction_arrays(target_data, count=1):
         """