diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 1500cdab23fca3058bffc838b6855fd0b3455f3d..a48a82b25804da34d07807073b0c153408e4e028 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -8,6 +8,7 @@ import logging
 import os
 import sys
 import traceback
+import copy
 from typing import Dict, Tuple, Union, List, Callable
 
 import numpy as np
@@ -726,13 +727,15 @@ class PostProcessing(RunEnvironment):
             # get scaling parameters
             transformation_func = data.apply_transformation
 
+            nn_output = self.model.predict(input_data)
+
             for normalised in [True, False]:
                 # create empty arrays
                 nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(
                     target_data, count=4)
 
                 # nn forecast
-                nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_func, normalised)
+                nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised)
 
                 # persistence
                 persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
@@ -859,7 +862,7 @@ class PostProcessing(RunEnvironment):
             persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True)
         return persistence_prediction
 
-    def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
+    def _create_nn_forecast(self, nn_output: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
                             normalised: bool) -> xr.DataArray:
         """
         Create NN forecast for given input data.
@@ -868,22 +871,22 @@ class PostProcessing(RunEnvironment):
         output of the main branch is returned (not all minor branches, if the network has multiple output branches). The
         main branch is defined to be the last entry of all outputs.
 
-        :param input_data: transposed history from DataPrep
+        :param nn_output: Full NN model output
         :param nn_prediction: empty array in right shape to fill with data
         :param transformation_func: a callable function to apply inverse transformation
         :param normalised: transform prediction in original space if false, or use normalised predictions if true
 
         :return: filled data array with nn predictions
         """
-        tmp_nn = self.model.predict(input_data)
-        if isinstance(tmp_nn, list):
-            nn_prediction.values = tmp_nn[-1]
-        elif tmp_nn.ndim == 3:
-            nn_prediction.values = tmp_nn[-1, ...]
-        elif tmp_nn.ndim == 2:
-            nn_prediction.values = tmp_nn
+
+        if isinstance(nn_output, list):
+            nn_prediction.values = nn_output[-1]
+        elif nn_output.ndim == 3:
+            nn_prediction.values = nn_output[-1, ...]
+        elif nn_output.ndim == 2:
+            nn_prediction.values = nn_output
         else:
-            raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
+            raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {nn_output.dims}.")
         if not normalised:
             nn_prediction = transformation_func(nn_prediction, base="target", inverse=True)
         return nn_prediction