diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index a47ef67ad5781ff37ce812aa931dbd195d4513dc..0f3ff6d436b8a65528626f5f80508af222a1e68f 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -15,7 +15,8 @@ from src.run_modules.run_environment import RunEnvironment from src.helpers import l_p_loss, LearningRateDecay from src.model_modules.inception_model import InceptionModelBase from src.model_modules.flatten import flatten_tail -from src.model_modules.model_class import MyLittleModel +# from src.model_modules.model_class import MyBranchedModel as MyModel +from src.model_modules.model_class import MyLittleModel as MyModel class ModelSetup(RunEnvironment): @@ -76,7 +77,7 @@ class ModelSetup(RunEnvironment): def build_model(self): args_list = ["window_history_size", "window_lead_time", "channels"] args = self.data_store.create_args_dict(args_list, self.scope) - self.model = MyLittleModel(**args) + self.model = MyModel(**args) self.get_model_settings() def get_model_settings(self): diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 35d93dcbd932d1c298c0744fcd0205697576bb4c..e5739e5f15e1c2f20758e388b3493c28f577bb9a 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -109,9 +109,25 @@ class PostProcessing(RunEnvironment): return persistence_prediction def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method): + """ + create the nn forecast for given input data. Inverse transformation is applied to the forecast to get the output + in the original space. Furthermore, only the output of the main branch is returned (not all minor branches, if + the network has multiple output branches). The main branch is defined to be the last entry of all outputs. + :param input_data: + :param nn_prediction: + :param mean: + :param std: + :param transformation_method: + :return: + """ tmp_nn = self.model.predict(input_data) tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method) - nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0) + if tmp_nn.ndim == 3: + nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0) + elif tmp_nn.ndim == 2: + nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0) + else: + raise NotImplementedError(f"Number of dimension of model output must be 2 or 3, but not {tmp_nn.dims}.") return nn_prediction @staticmethod