From 994df99bc4513dafcb730d88657e8a9daa6edf49 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 2 Mar 2020 12:40:39 +0100
Subject: [PATCH] use observation instead of selection from input data for
 observation creation

---
 src/run_modules/post_processing.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 5c392a40..962c9f52 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -168,7 +168,7 @@ class PostProcessing(RunEnvironment):
                 nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised)
 
                 # persistence
-                persistence_prediction = self._create_persistence_forecast(input_data, persistence_prediction, mean, std,
+                persistence_prediction = self._create_persistence_forecast(data, persistence_prediction, mean, std,
                                                                            transformation_method, normalised)
 
                 # ols
@@ -197,7 +197,7 @@ class PostProcessing(RunEnvironment):
 
     @staticmethod
     def _create_observation(data, _, mean, std, transformation_method, normalised):
-        obs = data.label.copy()
+        obs = data.observation.copy()
         if not normalised:
             obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method)
         return obs
@@ -211,8 +211,8 @@ class PostProcessing(RunEnvironment):
         ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols
         return ols_prediction
 
-    def _create_persistence_forecast(self, input_data, persistence_prediction, mean, std, transformation_method, normalised):
-        tmp_persi = input_data.sel({'window': 0, 'variables': self.target_var})
+    def _create_persistence_forecast(self, data, persistence_prediction, mean, std, transformation_method, normalised):
+        tmp_persi = data.observation.copy().sel({'window': 0})
         if not normalised:
             tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method)
         window_lead_time = self.data_store.get("window_lead_time", "general")
@@ -295,7 +295,7 @@ class PostProcessing(RunEnvironment):
         try:
             data = self.train_val_data.get_data_generator(station)
             mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
-            external_data = self._create_observation(data, None, mean, std, transformation_method)
+            external_data = self._create_observation(data, None, mean, std, transformation_method, normalised=False)
             external_data = external_data.squeeze("Stations").sel(window=1).drop(["window", "Stations", "variables"])
             return external_data.rename({'datetime': 'index'})
         except KeyError:
-- 
GitLab