Skip to content
Snippets Groups Projects
Commit f68ecb48 authored by leufen1's avatar leufen1
Browse files

/close #306 on pipeline success

parent 85a74b58
No related branches found
No related tags found
5 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!317enabled window_lead_time=1,!295Resolve "data handler FIR filter",!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #72609 passed
......@@ -85,6 +85,7 @@ class PostProcessing(RunEnvironment):
self.competitor_path = self.data_store.get("competitor_path")
self.competitors = to_list(self.data_store.get_default("competitors", default=[]))
self.forecast_indicator = "nn"
self.ahead_dim = "ahead"
self._run()
def _run(self):
......@@ -172,7 +173,7 @@ class PostProcessing(RunEnvironment):
bootstrap_path = self.data_store.get("bootstrap_path")
forecast_path = self.data_store.get("forecast_path")
number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing")
dims = ["index", "ahead", "type"]
dims = ["index", self.ahead_dim, "type"]
for station in self.test_data:
logging.info(str(station))
X, Y = None, None
......@@ -467,7 +468,8 @@ class PostProcessing(RunEnvironment):
"obs": observation,
"ols": ols_prediction}
all_predictions = self.create_forecast_arrays(full_index, list(target_data.indexes[window_dim]),
time_dimension, **prediction_dict)
time_dimension, ahead_dim=self.ahead_dim,
**prediction_dict)
# save all forecasts locally
path = self.data_store.get("forecast_path")
......@@ -618,7 +620,8 @@ class PostProcessing(RunEnvironment):
return index
@staticmethod
def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension, **kwargs):
def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension,
ahead_dim="ahead", **kwargs):
"""
Combine different forecast types into single xarray.
......@@ -631,7 +634,7 @@ class PostProcessing(RunEnvironment):
"""
keys = list(kwargs.keys())
res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
coords=[index.index, ahead_names, keys], dims=['index', 'ahead', 'type'])
coords=[index.index, ahead_names, keys], dims=['index', ahead_dim, 'type'])
for k, v in kwargs.items():
intersection = set(res.index.values) & set(v.indexes[time_dimension].values)
match_index = np.array(list(intersection))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment