Skip to content
Snippets Groups Projects
Commit a37d00bb authored by felix kleinert's avatar felix kleinert
Browse files

Merge branch 'felix_issue413-refac_remove-model-prediction-from-normalized-loop' into 'develop'

remove model.prediction from normalized loop

See merge request !463
parents ea2c08a2 3150d01f
Branches
Tags
5 merge requests!468first implementation of toar-data-v2, can load data (but cannot process these...,!467Resolve "release v2.2.0",!465Merge issue412 (create ens predictions for bnns) into 411 (include postprocessing for bnns),!464Merge Develop into 412,!463remove model.prediction from normalized loop
Pipeline #108814 passed with warnings
...@@ -8,6 +8,7 @@ import logging ...@@ -8,6 +8,7 @@ import logging
import os import os
import sys import sys
import traceback import traceback
import copy
from typing import Dict, Tuple, Union, List, Callable from typing import Dict, Tuple, Union, List, Callable
import numpy as np import numpy as np
...@@ -726,13 +727,15 @@ class PostProcessing(RunEnvironment): ...@@ -726,13 +727,15 @@ class PostProcessing(RunEnvironment):
# get scaling parameters # get scaling parameters
transformation_func = data.apply_transformation transformation_func = data.apply_transformation
nn_output = self.model.predict(input_data)
for normalised in [True, False]: for normalised in [True, False]:
# create empty arrays # create empty arrays
nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays( nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(
target_data, count=4) target_data, count=4)
# nn forecast # nn forecast
nn_prediction = self._create_nn_forecast(input_data, nn_prediction, transformation_func, normalised) nn_prediction = self._create_nn_forecast(copy.deepcopy(nn_output), nn_prediction, transformation_func, normalised)
# persistence # persistence
persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction, persistence_prediction = self._create_persistence_forecast(observation_data, persistence_prediction,
...@@ -859,7 +862,7 @@ class PostProcessing(RunEnvironment): ...@@ -859,7 +862,7 @@ class PostProcessing(RunEnvironment):
persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True) persistence_prediction = transformation_func(persistence_prediction, "target", inverse=True)
return persistence_prediction return persistence_prediction
def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable, def _create_nn_forecast(self, nn_output: xr.DataArray, nn_prediction: xr.DataArray, transformation_func: Callable,
normalised: bool) -> xr.DataArray: normalised: bool) -> xr.DataArray:
""" """
Create NN forecast for given input data. Create NN forecast for given input data.
...@@ -868,22 +871,22 @@ class PostProcessing(RunEnvironment): ...@@ -868,22 +871,22 @@ class PostProcessing(RunEnvironment):
output of the main branch is returned (not all minor branches, if the network has multiple output branches). The output of the main branch is returned (not all minor branches, if the network has multiple output branches). The
main branch is defined to be the last entry of all outputs. main branch is defined to be the last entry of all outputs.
:param input_data: transposed history from DataPrep :param nn_output: Full NN model output
:param nn_prediction: empty array in right shape to fill with data :param nn_prediction: empty array in right shape to fill with data
:param transformation_func: a callable function to apply inverse transformation :param transformation_func: a callable function to apply inverse transformation
:param normalised: transform prediction in original space if false, or use normalised predictions if true :param normalised: transform prediction in original space if false, or use normalised predictions if true
:return: filled data array with nn predictions :return: filled data array with nn predictions
""" """
tmp_nn = self.model.predict(input_data)
if isinstance(tmp_nn, list): if isinstance(nn_output, list):
nn_prediction.values = tmp_nn[-1] nn_prediction.values = nn_output[-1]
elif tmp_nn.ndim == 3: elif nn_output.ndim == 3:
nn_prediction.values = tmp_nn[-1, ...] nn_prediction.values = nn_output[-1, ...]
elif tmp_nn.ndim == 2: elif nn_output.ndim == 2:
nn_prediction.values = tmp_nn nn_prediction.values = nn_output
else: else:
raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.") raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {nn_output.dims}.")
if not normalised: if not normalised:
nn_prediction = transformation_func(nn_prediction, base="target", inverse=True) nn_prediction = transformation_func(nn_prediction, base="target", inverse=True)
return nn_prediction return nn_prediction
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment