diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 5c392a402da47251c51668e0b06a3067104a61e6..962c9f52065729381ce11e8a8adcbeed45a4c011 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -168,7 +168,7 @@ class PostProcessing(RunEnvironment): nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised) # persistence - persistence_prediction = self._create_persistence_forecast(input_data, persistence_prediction, mean, std, + persistence_prediction = self._create_persistence_forecast(data, persistence_prediction, mean, std, transformation_method, normalised) # ols @@ -197,7 +197,7 @@ class PostProcessing(RunEnvironment): @staticmethod def _create_observation(data, _, mean, std, transformation_method, normalised): - obs = data.label.copy() + obs = data.observation.copy() if not normalised: obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method) return obs @@ -211,8 +211,8 @@ class PostProcessing(RunEnvironment): ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols return ols_prediction - def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method, normalised): - tmp_persi = input_data.sel({'window': 0, 'variables': self.target_var}) + def _create_persistence_forecast(self, data, persistence_prediction, mean, std, transformation_method, normalised): + tmp_persi = data.observation.copy().sel({'window': 0}) if not normalised: tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) window_lead_time = self.data_store.get("window_lead_time", "general") @@ -295,7 +295,7 @@ class PostProcessing(RunEnvironment): try: data = self.train_val_data.get_data_generator(station) mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) - external_data = self._create_observation(data, None, mean, std, transformation_method) + external_data = self._create_observation(data, None, mean, std, transformation_method, normalised=False) external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"]) return external_data.rename({'datetime': 'index'}) except KeyError: