Skip to content
Snippets Groups Projects
Commit 3bdf3ea0 authored by lukas leufen's avatar lukas leufen
Browse files

fixed errors with branched models

parent 9ed8db81
No related branches found
No related tags found
2 merge requests!59Develop,!57Lukas issue 064 bug check time axis
Pipeline #31230 passed
...@@ -65,6 +65,8 @@ class PostProcessing(RunEnvironment): ...@@ -65,6 +65,8 @@ class PostProcessing(RunEnvironment):
with TimeTracking(name="boot predictions"): with TimeTracking(name="boot predictions"):
bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(), bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
steps=bootstraps.get_boot_strap_generator_length()) steps=bootstraps.get_boot_strap_generator_length())
if isinstance(bootstrap_predictions, list):
bootstrap_predictions = bootstrap_predictions[-1]
bootstrap_meta = np.array(bootstraps.get_boot_strap_meta()) bootstrap_meta = np.array(bootstraps.get_boot_strap_meta())
variables = np.unique(bootstrap_meta[:, 0]) variables = np.unique(bootstrap_meta[:, 0])
for station in np.unique(bootstrap_meta[:, 1]): for station in np.unique(bootstrap_meta[:, 1]):
...@@ -211,7 +213,7 @@ class PostProcessing(RunEnvironment): ...@@ -211,7 +213,7 @@ class PostProcessing(RunEnvironment):
return ols_prediction return ols_prediction
def _create_persistence_forecast(self, data, persistence_prediction, mean, std, transformation_method, normalised): def _create_persistence_forecast(self, data, persistence_prediction, mean, std, transformation_method, normalised):
tmp_persi = data.observation.copy().sel({'window': 0})#.shift(datetime=1) tmp_persi = data.observation.copy().sel({'window': 0})
if not normalised: if not normalised:
tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
...@@ -234,7 +236,9 @@ class PostProcessing(RunEnvironment): ...@@ -234,7 +236,9 @@ class PostProcessing(RunEnvironment):
tmp_nn = self.model.predict(input_data) tmp_nn = self.model.predict(input_data)
if not normalised: if not normalised:
tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method) tmp_nn = statistics.apply_inverse_transformation(tmp_nn, mean, std, transformation_method)
if tmp_nn.ndim == 3: if isinstance(tmp_nn, list):
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1], axis=1), 2, 0)
elif tmp_nn.ndim == 3:
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0) nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn[-1, ...], axis=1), 2, 0)
elif tmp_nn.ndim == 2: elif tmp_nn.ndim == 2:
nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0) nn_prediction.values = np.swapaxes(np.expand_dims(tmp_nn, axis=1), 2, 0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment