From f68ecb48b2432a039ed12da5a40a327b0addcfb2 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 8 Jul 2021 16:25:30 +0200
Subject: [PATCH] /close #306 on pipeline success

---
 mlair/run_modules/post_processing.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index 89a6f205..0d7bfeb4 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -85,6 +85,7 @@ class PostProcessing(RunEnvironment):
         self.competitor_path = self.data_store.get("competitor_path")
         self.competitors = to_list(self.data_store.get_default("competitors", default=[]))
         self.forecast_indicator = "nn"
+        self.ahead_dim = "ahead"
         self._run()
 
     def _run(self):
@@ -172,7 +173,7 @@ class PostProcessing(RunEnvironment):
             bootstrap_path = self.data_store.get("bootstrap_path")
             forecast_path = self.data_store.get("forecast_path")
             number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
-            dims = ["index", "ahead", "type"]
+            dims = ["index", self.ahead_dim, "type"]
             for station in self.test_data:
                 logging.info(str(station))
                 X, Y = None, None
@@ -467,7 +468,8 @@ class PostProcessing(RunEnvironment):
                                    "obs": observation,
                                    "ols": ols_prediction}
                 all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]),
-                                                              time_dimension, **prediction_dict)
+                                                              time_dimension, ahead_dim=self.ahead_dim,
+                                                              **prediction_dict)
 
                 # save all forecasts locally
                 path = self.data_store.get("forecast_path")
@@ -618,7 +620,8 @@ class PostProcessing(RunEnvironment):
         return index
 
     @staticmethod
-    def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, **kwargs):
+    def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension,
+                               ahead_dim="ahead", **kwargs):
         """
         Combine different forecast types into single xarray.
 
@@ -631,7 +634,7 @@ class PostProcessing(RunEnvironment):
         """
         keys = list(kwargs.keys())
         res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
-                           coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type'])
+                           coords=[index.index, ahead_names, keys], dims=['index', ahead_dim, 'type'])
         for k, v in kwargs.items():
             intersection = set(res.index.values) & set(v.indexes[time_dimension].values)
             match_index = np.array(list(intersection))
-- 
GitLab