Skip to content
Snippets Groups Projects

remove model.prediction from normalized loop

1 file
+ 14
11
Compare changes
  • Side-by-side
  • Inline
@@ -8,6 +8,7 @@ import logging
import os
import sys
import traceback
import copy
from typing import Dict, Tuple, Union, List, Callable
import numpy as np
@@ -726,13 +727,15 @@ class PostProcessing(RunEnvironment):
# get scaling parameters
transformation_func = data.apply_transformation
nn_output = self.model.predict(input_data)
for normalised in [True, False]:
# create empty arrays
nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(
target_data, count=4)
# nn forecast
nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_func, normalised)
nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised)
# persistence
persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
@@ -859,7 +862,7 @@ class PostProcessing(RunEnvironment):
persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True)
return persistence_prediction
def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
def _create_nn_forecast(self, nn_output: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
normalised: bool) -> xr.DataArray:
"""
Create NN forecast for given input data.
@@ -868,22 +871,22 @@ class PostProcessing(RunEnvironment):
output of the main branch is returned (not all minor branches, if the network has multiple output branches). The
main branch is defined to be the last entry of all outputs.
:param input_data: transposed history from DataPrep
:param nn_output: Full NN model output
:param nn_prediction: empty array in right shape to fill with data
:param transformation_func: a callable function to apply inverse transformation
:param normalised: transform prediction in original space if false, or use normalised predictions if true
:return: filled data array with nn predictions
"""
tmp_nn = self.model.predict(input_data)
if isinstance(tmp_nn, list):
nn_prediction.values = tmp_nn[-1]
elif tmp_nn.ndim == 3:
nn_prediction.values = tmp_nn[-1, ...]
elif tmp_nn.ndim == 2:
nn_prediction.values = tmp_nn
if isinstance(nn_output, list):
nn_prediction.values = nn_output[-1]
elif nn_output.ndim == 3:
nn_prediction.values = nn_output[-1, ...]
elif nn_output.ndim == 2:
nn_prediction.values = nn_output
else:
raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {nn_output.dims}.")
if not normalised:
nn_prediction = transformation_func(nn_prediction, base="target", inverse=True)
return nn_prediction
Loading