From f68ecb48b2432a039ed12da5a40a327b0addcfb2 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Thu, 8 Jul 2021 16:25:30 +0200 Subject: [PATCH] /close #306 on pipeline success --- mlair/run_modules/post_processing.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 89a6f205..0d7bfeb4 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -85,6 +85,7 @@ class PostProcessing(RunEnvironment): self.competitor_path = self.data_store.get("competitor_path") self.competitors = to_list(self.data_store.get_default("competitors", default=[])) self.forecast_indicator = "nn" + self.ahead_dim = "ahead" self._run() def _run(self): @@ -172,7 +173,7 @@ class PostProcessing(RunEnvironment): bootstrap_path = self.data_store.get("bootstrap_path") forecast_path = self.data_store.get("forecast_path") number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing") - dims = ["index", "ahead", "type"] + dims = ["index", self.ahead_dim, "type"] for station in self.test_data: logging.info(str(station)) X, Y = None, None @@ -467,7 +468,8 @@ class PostProcessing(RunEnvironment): "obs": observation, "ols": ols_prediction} all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]), - time_dimension, **prediction_dict) + time_dimension, ahead_dim=self.ahead_dim, + **prediction_dict) # save all forecasts locally path = self.data_store.get("forecast_path") @@ -618,7 +620,8 @@ class PostProcessing(RunEnvironment): return index @staticmethod - def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, **kwargs): + def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, + ahead_dim="ahead", **kwargs): """ Combine different forecast types into single xarray. @@ -631,7 +634,7 @@ class PostProcessing(RunEnvironment): """ keys = list(kwargs.keys()) res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan), - coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type']) + coords=[index.index, ahead_names, keys], dims=['index', ahead_dim, 'type']) for k, v in kwargs.items(): intersection = set(res.index.values) & set(v.indexes[time_dimension].values) match_index = np.array(list(intersection)) -- GitLab