Skip to content
Snippets Groups Projects
Commit bf0f2d3c authored by lukas leufen's avatar lukas leufen
Browse files

applied bug fix for #6, program can handle now networks with single or...

applied bug fix for #6, program can handle now networks with single or multiple output branches in the right manner
parent 7a9b091d
No related branches found
No related tags found
2 merge requests!37include new development,!25fixed bug: make prediction with correct dims
Pipeline #28446 passed
...@@ -15,7 +15,8 @@ from src.run_modules.run_environment import RunEnvironment ...@@ -15,7 +15,8 @@ from src.run_modules.run_environment import RunEnvironment
from src.helpers import l_p_loss, LearningRateDecay from src.helpers import l_p_loss, LearningRateDecay
from src.model_modules.inception_model import InceptionModelBase from src.model_modules.inception_model import InceptionModelBase
from src.model_modules.flatten import flatten_tail from src.model_modules.flatten import flatten_tail
from src.model_modules.model_class import MyLittleModel # from src.model_modules.model_class import MyBranchedModel as MyModel
from src.model_modules.model_class import MyLittleModel as MyModel
class ModelSetup(RunEnvironment): class ModelSetup(RunEnvironment):
...@@ -76,7 +77,7 @@ class ModelSetup(RunEnvironment): ...@@ -76,7 +77,7 @@ class ModelSetup(RunEnvironment):
def build_model(self): def build_model(self):
args_list = ["window_history_size", "window_lead_time", "channels"] args_list = ["window_history_size", "window_lead_time", "channels"]
args = self.data_store.create_args_dict(args_list, self.scope) args = self.data_store.create_args_dict(args_list, self.scope)
self.model = MyLittleModel(**args) self.model = MyModel(**args)
self.get_model_settings() self.get_model_settings()
def get_model_settings(self): def get_model_settings(self):
......
...@@ -109,9 +109,25 @@ class PostProcessing(RunEnvironment): ...@@ -109,9 +109,25 @@ class PostProcessing(RunEnvironment):
return persistence_prediction return persistence_prediction
def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method): def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method):
"""
create the nn forecast for given input data. Inverse transformation is applied to the forecast to get the output
in the original space. Furthermore, only the output of the main branch is returned (not all minor branches, if
the network has multiple output branches). The main branch is defined to be the last entry of all outputs.
:param input_data:
:param nn_prediction:
:param mean:
:param std:
:param transformation_method:
:return:
"""
tmp_nn = self.model.predict(input_data) tmp_nn = self.model.predict(input_data)
tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method) tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method)
if tmp_nn.ndim == 3:
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0)
elif tmp_nn.ndim == 2:
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0) nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0)
else:
raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.")
return nn_prediction return nn_prediction
@staticmethod @staticmethod
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment