Skip to content
Snippets Groups Projects
Commit 994df99b authored by lukas leufen's avatar lukas leufen
Browse files

use observation instead of selection from input data for observation creation

parent b1763fcd
No related branches found
No related tags found
2 merge requests!59Develop,!54Lukas issue061 refac seperate input target vars
Pipeline #30907 passed
...@@ -168,7 +168,7 @@ class PostProcessing(RunEnvironment): ...@@ -168,7 +168,7 @@ class PostProcessing(RunEnvironment):
nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised) nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised)
# persistence # persistence
persistence_prediction = self._create_persistence_forecast(input_data, persistence_prediction, mean, std, persistence_prediction = self._create_persistence_forecast(data, persistence_prediction, mean, std,
transformation_method, normalised) transformation_method, normalised)
# ols # ols
...@@ -197,7 +197,7 @@ class PostProcessing(RunEnvironment): ...@@ -197,7 +197,7 @@ class PostProcessing(RunEnvironment):
@staticmethod @staticmethod
def _create_observation(data, _, mean, std, transformation_method, normalised): def _create_observation(data, _, mean, std, transformation_method, normalised):
obs = data.label.copy() obs = data.observation.copy()
if not normalised: if not normalised:
obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method) obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method)
return obs return obs
...@@ -211,8 +211,8 @@ class PostProcessing(RunEnvironment): ...@@ -211,8 +211,8 @@ class PostProcessing(RunEnvironment):
ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
return ols_prediction return ols_prediction
def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method, normalised): def _create_persistence_forecast(self, data, persistence_prediction, mean, std, transformation_method, normalised):
tmp_persi = input_data.sel({'window': 0, 'variables': self.target_var}) tmp_persi = data.observation.copy().sel({'window': 0})
if not normalised: if not normalised:
tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
...@@ -295,7 +295,7 @@ class PostProcessing(RunEnvironment): ...@@ -295,7 +295,7 @@ class PostProcessing(RunEnvironment):
try: try:
data = self.train_val_data.get_data_generator(station) data = self.train_val_data.get_data_generator(station)
mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
external_data = self._create_observation(data, None, mean, std, transformation_method) external_data = self._create_observation(data, None, mean, std, transformation_method, normalised=False)
external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"]) external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"])
return external_data.rename({'datetime': 'index'}) return external_data.rename({'datetime': 'index'})
except KeyError: except KeyError:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment