diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 1500cdab23fca3058bffc838b6855fd0b3455f3d..a48a82b25804da34d07807073b0c153408e4e028 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -8,6 +8,7 @@ import logging import os import sys import traceback +import copy from typing import Dict, Tuple, Union, List, Callable import numpy as np @@ -726,13 +727,15 @@ class PostProcessing(RunEnvironment): # get scaling parameters transformation_func = data.apply_transformation + nn_output = self.model.predict(input_data) + for normalised in [True, False]: # create empty arrays nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays( target_data, count=4) # nn forecast - nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_func, normalised) + nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised) # persistence persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, @@ -859,7 +862,7 @@ class PostProcessing(RunEnvironment): persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True) return persistence_prediction - def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable, + def _create_nn_forecast(self, nn_output: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable, normalised: bool) -> xr.DataArray: """ Create NN forecast for given input data. @@ -868,22 +871,22 @@ class PostProcessing(RunEnvironment): output of the main branch is returned (not all minor branches, if the network has multiple output branches). The main branch is defined to be the last entry of all outputs. - :param input_data: transposed history from DataPrep + :param nn_output: Full NN model output :param nn_prediction: empty array in right shape to fill with data :param transformation_func: a callable function to apply inverse transformation :param normalised: transform prediction in original space if false, or use normalised predictions if true :return: filled data array with nn predictions """ - tmp_nn = self.model.predict(input_data) - if isinstance(tmp_nn, list): - nn_prediction.values = tmp_nn[-1] - elif tmp_nn.ndim == 3: - nn_prediction.values = tmp_nn[-1, ...] - elif tmp_nn.ndim == 2: - nn_prediction.values = tmp_nn + + if isinstance(nn_output, list): + nn_prediction.values = nn_output[-1] + elif nn_output.ndim == 3: + nn_prediction.values = nn_output[-1, ...] + elif nn_output.ndim == 2: + nn_prediction.values = nn_output else: - raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.") + raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {nn_output.dims}.") if not normalised: nn_prediction = transformation_func(nn_prediction, base="target", inverse=True) return nn_prediction