From 3150d01f2f800ca00befabf283bf16badb41e701 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Tue, 9 Aug 2022 08:59:41 +0200
Subject: [PATCH] remove model.prediction from normalized loop

---
 mlair/run_modules/post_processing.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 1500cdab..a48a82b2 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -8,6 +8,7 @@ import logging
 import os
 import sys
 import traceback
+import copy
 from typing import Dict, Tuple, Union, List, Callable
 
 import numpy as np
@@ -726,13 +727,15 @@ class PostProcessing(RunEnvironment):
             # get scaling parameters
             transformation_func = data.apply_transformation
 
+            nn_output = self.model.predict(input_data)
+
             for normalised in [True, False]:
                 # create empty arrays
                 nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(
                     target_data, count=4)
 
                 # nn forecast
-                nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_func, normalised)
+                nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised)
 
                 # persistence
                 persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
@@ -859,7 +862,7 @@ class PostProcessing(RunEnvironment):
             persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True)
         return persistence_prediction
 
-    def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
+    def _create_nn_forecast(self, nn_output: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
                             normalised: bool) -> xr.DataArray:
         """
         Create NN forecast for given input data.
@@ -868,22 +871,22 @@ class PostProcessing(RunEnvironment):
         output of the main branch is returned (not all minor branches, if the network has multiple output branches). The
         main branch is defined to be the last entry of all outputs.
 
-        :param input_data: transposed history from DataPrep
+        :param nn_output: Full NN model output
         :param nn_prediction: empty array in right shape to fill with data
         :param transformation_func: a callable function to apply inverse transformation
         :param normalised: transform prediction in original space if false, or use normalised predictions if true
 
         :return: filled data array with nn predictions
         """
-        tmp_nn = self.model.predict(input_data)
-        if isinstance(tmp_nn, list):
-            nn_prediction.values = tmp_nn[-1]
-        elif tmp_nn.ndim == 3:
-            nn_prediction.values = tmp_nn[-1, ...]
-        elif tmp_nn.ndim == 2:
-            nn_prediction.values = tmp_nn
+
+        if isinstance(nn_output, list):
+            nn_prediction.values = nn_output[-1]
+        elif nn_output.ndim == 3:
+            nn_prediction.values = nn_output[-1, ...]
+        elif nn_output.ndim == 2:
+            nn_prediction.values = nn_output
         else:
-            raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
+            raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {nn_output.dims}.")
         if not normalised:
             nn_prediction = transformation_func(nn_prediction, base="target", inverse=True)
         return nn_prediction
-- 
GitLab