diff --git a/src/data_handling/advanced_data_handling.py b/src/data_handling/advanced_data_handling.py
index f48bbb22cb5a52df540bf76b517f38e7b062511b..78f5792ef190eeb06b60d998a808ed5c8c1c7899 100644
--- a/src/data_handling/advanced_data_handling.py
+++ b/src/data_handling/advanced_data_handling.py
@@ -191,6 +191,12 @@ class DefaultDataPreparation(AbstractDataPreparation):
             Y = Y_original.sel({dim: intersect})
         self._X, self._Y = X, Y
 
+    def get_observation(self):
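+        """Return a copy of the underlying station's observation data, squeezed to drop dimensions of size one."""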
+        return self.id_class.observation.copy().squeeze()
+
+    def get_transformation_Y(self):
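+        """Return the transformation statistics (mean, std, method) of the target variable."""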
+        return self.id_class.get_transformation_information()
+
     def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False,
                           timedelta: Tuple[int, str] = (1, 'm'), dim="datetime"):
         """
@@ -265,12 +271,10 @@ class DefaultDataPreparation(AbstractDataPreparation):
         transformation_dict = sp_keys.pop("transformation")
         if transformation_dict is None:
             return
-
         scope = transformation_dict.pop("scope")
         method = transformation_dict.pop("method")
         if transformation_dict.pop("mean", None) is not None:
             return
-
         mean, std = None, None
         for station in set_stations:
             try:
@@ -286,8 +290,6 @@ class DefaultDataPreparation(AbstractDataPreparation):
         return {"scope": scope, "method": method, "mean": mean_estimated, "std": std_estimated}
 
 
-
-
 def run_data_prep():
 
     data = DummyDataSingleStation("main_class")
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 447500203c7f92c8db5f4ece6edc195587565b6b..453a203cc80aa950e2d5d0097c6f9bd3c3b15a7d 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -603,6 +603,29 @@ class StationPrep(AbstractStationPrep):
         self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method)
         self._transform_method = None
 
+    def get_transformation_information(self, variable: str = None) -> Tuple[data_or_none, data_or_none, str]:
+        """
+        Extract transformation statistics and method.
+
+        Get mean and standard deviation for the given variable and the transformation method if set. If a
+        transformation depends only on particular statistics (e.g. only the mean is required for centering), the
+        remaining statistics are returned as None.
+
+        :param variable: Variable for which the information on transformation is requested.
+
+        :return: mean, standard deviation and transformation method
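+
+        A minimal usage sketch (station and variable names are illustrative; assumes a transformation was applied)::
+
+            mean, std, method = station.get_transformation_information(variable="o3")
+            # mean and std are per-variable values, or None if the method does not use them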
+        """
+        variable = self.target_var if variable is None else variable
+        try:
+            mean = self.mean.sel({'variables': variable}).values
+        except AttributeError:
+            mean = None
+        try:
+            std = self.std.sel({'variables': variable}).values
+        except AttributeError:
+            std = None
+        return mean, std, self._transform_method
+
 
 class AbstractDataPrep(object):
     """
diff --git a/src/model_modules/linear_model.py b/src/model_modules/linear_model.py
index 85c062119d881635bf54aadb7491a3c0298e64f5..341c787e3060fd7e7cc3ff468ba40add9b9936d2 100644
--- a/src/model_modules/linear_model.py
+++ b/src/model_modules/linear_model.py
@@ -55,7 +55,7 @@ class OrdinaryLeastSquaredModel:
 
     def predict(self, data):
         """Apply OLS model on data."""
-        data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add")
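+        # concatenate the flattened input branches along the feature axis before adding the intercept column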
+        data = sm.add_constant(np.concatenate(self.flatten(data), axis=1), has_constant="add")
         return np.atleast_2d(self.model.predict(data))
 
     @staticmethod
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 2512244c8f9516becfb0edec48a3c9f82e5643de..66744b8022f71103c8686de8cd85cc4016e7f932 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -312,12 +312,14 @@ class PostProcessing(RunEnvironment):
         be found inside `forecast_path`.
         """
         logging.debug("start make_prediction")
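+        # the dimension used for interpolation serves as the time dimension of all forecast arrays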
+        time_dimension = self.data_store.get("interpolate_dim")
         for i, data in enumerate(self.test_data):
             input_data = data.get_X()
-            target_data = data.get_Y()
+            target_data = data.get_Y(as_numpy=False)
+            observation_data = data.get_observation()
 
             # get scaling parameters
-            # mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
+            mean, std, transformation_method = data.get_transformation_Y()
 
             for normalised in [True, False]:
                 # create empty arrays
@@ -329,7 +331,7 @@ class PostProcessing(RunEnvironment):
                                                          normalised)
 
                 # persistence
-                persistence_prediction = self._create_persistence_forecast(data, persistence_prediction, mean, std,
-                                                                           transformation_method, normalised)
+                persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
+                                                                           mean, std, transformation_method, normalised)
 
                 # ols
@@ -337,11 +339,12 @@ class PostProcessing(RunEnvironment):
                                                            normalised)
 
                 # observation
-                observation = self._create_observation(data, observation, mean, std, transformation_method, normalised)
+                observation = self._create_observation(target_data, observation, mean, std, transformation_method,
+                                                       normalised)
 
                 # merge all predictions
-                full_index = self.create_fullindex(data.data.indexes['datetime'], self._get_frequency())
-                all_predictions = self.create_forecast_arrays(full_index, list(data.label.indexes['window']),
+                full_index = self.create_fullindex(observation_data.indexes[time_dimension], self._get_frequency())
+                all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes['window']),
+                                                              time_dimension,
                                                               CNN=nn_prediction,
                                                               persi=persistence_prediction,
                                                               obs=observation,
@@ -350,7 +353,7 @@ class PostProcessing(RunEnvironment):
                 # save all forecasts locally
                 path = self.data_store.get("forecast_path")
                 prefix = "forecasts_norm" if normalised else "forecasts"
-                file = os.path.join(path, f"{prefix}_{data.station[0]}_test.nc")
+                file = os.path.join(path, f"{prefix}_{str(data)}_test.nc")
                 all_predictions.to_netcdf(file)
 
     def _get_frequency(self) -> str:
@@ -359,14 +362,14 @@ class PostProcessing(RunEnvironment):
         return getter.get(self._sampling, None)
 
     @staticmethod
-    def _create_observation(data: DataPrepJoin, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
+    def _create_observation(data, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str,
                             normalised: bool) -> xr.DataArray:
         """
         Create observation as ground truth from given data.
 
         Inverse transformation is applied to the ground truth to get the output in the original space.
 
-        :param data: transposed observation from DataPrep
+        :param data: ground truth observation (target data)
         :param mean: mean of target value transformation
         :param std: standard deviation of target value transformation
         :param transformation_method: target values transformation method
@@ -374,10 +377,9 @@ class PostProcessing(RunEnvironment):
 
         :return: filled data array with observation
         """
-        obs = data.label.copy()
         if not normalised:
-            obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method)
-        return obs
+            data = statistics.apply_inverse_transformation(data, mean, std, transformation_method)
+        return data
 
     def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray, mean: xr.DataArray,
                              std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
@@ -398,12 +400,11 @@ class PostProcessing(RunEnvironment):
         tmp_ols = self.ols_model.predict(input_data)
         if not normalised:
             tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method)
-        tmp_ols = np.expand_dims(tmp_ols, axis=1)
         target_shape = ols_prediction.values.shape
         ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
         return ols_prediction
 
-    def _create_persistence_forecast(self, data: DataPrepJoin, persistence_prediction: xr.DataArray, mean: xr.DataArray,
+    def _create_persistence_forecast(self, data, persistence_prediction: xr.DataArray, mean: xr.DataArray,
                                      std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray:
         """
         Create persistence forecast with given data.
@@ -411,7 +412,7 @@ class PostProcessing(RunEnvironment):
-        Persistence is deviated from the value at t=0 and applied to all following time steps (t+1, ..., t+window).
+        Persistence is derived from the value at t=0 and applied to all following time steps (t+1, ..., t+window).
         Inverse transformation is applied to the forecast to get the output in the original space.
 
-        :param data: DataPrep
+        :param data: observation data used as persistence value for all lead times
         :param persistence_prediction: empty array in right shape to fill with data
         :param mean: mean of target value transformation
         :param std: standard deviation of target value transformation
@@ -420,12 +421,11 @@ class PostProcessing(RunEnvironment):
 
         :return: filled data array with persistence predictions
         """
-        tmp_persi = data.observation.copy().sel({'window': 0})
+        tmp_persi = data.copy()
         if not normalised:
             tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
         window_lead_time = self.data_store.get("window_lead_time")
-        persistence_prediction.values = np.expand_dims(np.tile(tmp_persi.squeeze('Stations'), (window_lead_time, 1)),
-                                                       axis=1)
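+        # repeat the t=0 observation over all lead times and transpose so rows align with the time index of the target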
+        persistence_prediction.values = np.tile(tmp_persi, (window_lead_time, 1)).T
         return persistence_prediction
 
     def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, mean: xr.DataArray,
@@ -450,17 +450,19 @@ class PostProcessing(RunEnvironment):
         if not normalised:
             tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method)
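+        # the model may return a list of outputs or a single 2D/3D array; for multiple outputs keep only the last one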
         if isinstance(tmp_nn, list):
-            nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1], axis=1), 2, 0)
+            nn_prediction.values = tmp_nn[-1]
         elif tmp_nn.ndim == 3:
-            nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0)
+            nn_prediction.values = tmp_nn[-1, ...]
         elif tmp_nn.ndim == 2:
-            nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0)
+            nn_prediction.values = tmp_nn
         else:
             raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
         return nn_prediction
 
     @staticmethod
     def _create_empty_prediction_arrays(target_data, count=1):
+        """
+        Create array to collect all predictions. Expand target data by a station dimension. """
         return [target_data.copy() for _ in range(count)]
 
     @staticmethod
@@ -489,7 +491,7 @@ class PostProcessing(RunEnvironment):
         return index
 
     @staticmethod
-    def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], **kwargs):
+    def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension: str, **kwargs):
         """
         Combine different forecast types into single xarray.
 
@@ -504,12 +506,8 @@ class PostProcessing(RunEnvironment):
         res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
                            coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type'])
         for k, v in kwargs.items():
-            try:
-                match_index = np.stack(set(res.index.values) & set(v.index.values))
-                res.loc[match_index, :, k] = v.loc[match_index]
-            except AttributeError:  # v is xarray type and has no attribute .index
-                match_index = np.stack(set(res.index.values) & set(v.indexes['datetime'].values))
-                res.loc[match_index, :, k] = v.sel({'datetime': match_index}).squeeze('Stations').transpose()
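+            # align on the time steps shared by the full index and each forecast before assignment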
+            match_index = np.stack(set(res.index.values) & set(v.indexes[time_dimension].values))
+            res.loc[match_index, :, k] = v.loc[match_index]
         return res
 
     def _get_external_data(self, station: str) -> Union[xr.DataArray, None]: