Commit 972d44e2 authored by Felix Kleinert

include ens preds with full time index - incl. missing values

parent 1ff4df7c
2 merge requests: !474 Draft: Resolve "DataHandler with multiple stats per variable", !466 Draft: Resolve "Include CRPS analysis and other ens verif methods or plots"
Pipeline #108516 passed
@@ -11,6 +11,7 @@ import sys
 import traceback
 import copy
 from typing import Dict, Tuple, Union, List, Callable
+import ensverif
 import numpy as np
 import pandas as pd
@@ -89,6 +90,7 @@ class PostProcessing(RunEnvironment):
         self.num_realizations = self.data_store.get("num_realizations", "postprocessing")
         self.ens_realization_dim = self.data_store.get("ens_realization_dim", "postprocessing")
         self.ens_moment_dim = self.data_store.get("ens_moment_dim", "postprocessing")
+        self.iter_dim = self.data_store.get("iter_dim")
         self.window_lead_time = extract_value(self.data_store.get("output_shape", "model"))
         self.skill_scores = None
         self.feature_importance_skill_scores = None
@@ -786,7 +788,23 @@ class PostProcessing(RunEnvironment):
                 })
             nn_ens_dist_predictions = self._create_nn_ens_forecast(ens_collector, nn_ens_dist_prediction,
                                                                    transformation_func, normalised)
-            all_predictions_ens = xr.Dataset({"ens": nn_ens_dist_predictions,
+            nn_ens_dist_predictions_full = self.create_forecast_arrays(
+                full_index, list(target_data.indexes[window_dim]), time_dimension,
+                ahead_dim=self.ahead_dim,
+                index_dim=self.index_dim, type_dim=self.model_type_dim,
+                ens_dims=[
+                    self.ens_realization_dim,
+                    self.ens_moment_dim],
+                ens_coords=[
+                    range(self.num_realizations),
+                    ["ens_dist_mean", "ens_dist_stddev"]],
+                **{"ens": nn_ens_dist_predictions.transpose("datetime", ...)}
+            )
+            nn_ens_dist_predictions_full = nn_ens_dist_predictions_full.expand_dims(
+                {self.iter_dim: to_list(str(nn_ens_dist_predictions[self.iter_dim].values))}
+            ).transpose(self.index_dim, ...)
+            all_predictions_ens = xr.Dataset({"ens": nn_ens_dist_predictions_full,
                                               "det": all_predictions,
                                               })
             file_ens = os.path.join(self.forecast_path, f"{prefix}_{str(data)}_ens_{subset_type}")
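
For orientation, here is a minimal standalone sketch (not MLAir code) of what the new block above does: the ensemble distribution parameters are written into a NaN-filled array that spans the full time index, so dates without predictions stay NaN, and a station dimension is added afterwards via expand_dims. The dimension names and the station id used below are illustrative assumptions only.

import numpy as np
import pandas as pd
import xarray as xr

full_index = pd.date_range("2020-01-01", periods=6, freq="D")   # full time index
avail_index = full_index[[0, 2, 3]]                              # dates with predictions

# ensemble distribution parameters for the available dates only
ens = xr.DataArray(np.random.rand(3, 2, 2, 4),
                   coords=[avail_index, range(2), ["ens_dist_mean", "ens_dist_stddev"], range(1, 5)],
                   dims=["datetime", "realization", "moment", "ahead"])

# NaN template over the full index with matching ensemble dimensions
full = xr.DataArray(np.full((6, 2, 2, 4), np.nan),
                    coords=[full_index, range(2), ["ens_dist_mean", "ens_dist_stddev"], range(1, 5)],
                    dims=["datetime", "realization", "moment", "ahead"])
full.loc[avail_index, ...] = ens        # dates without predictions remain NaN

# add a station dimension, comparable to expand_dims({self.iter_dim: [station]}) above
full = full.expand_dims({"station": ["DEBW107"]}).transpose("datetime", ...)
print(full.dims)  # ('datetime', 'station', 'realization', 'moment', 'ahead')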
@@ -794,8 +812,6 @@
             with open(f"{file_ens}_dist.pkl", 'wb') as outp:
                 pickle.dump(ens_collector, outp, pickle.HIGHEST_PROTOCOL)

     @staticmethod
     def _create_ens_mean_pred(collector):
         """Calculates the ens. mean from a list containing ens. members of type tfp.distributions._TensorCoercible"""
@@ -1002,7 +1018,8 @@ class PostProcessing(RunEnvironment):
     @staticmethod
     def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], time_dimension,
-                               ahead_dim="ahead", index_dim="index", type_dim="type", **kwargs):
+                               ahead_dim="ahead", index_dim="index", type_dim="type",
+                               ens_coords=None, ens_dims=None, **kwargs):
         """
         Combine different forecast types into single xarray.
@@ -1015,12 +1032,22 @@
         """
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
        keys = list(kwargs.keys())
-        res = xr.DataArray(np.full((len(index.index), len(ahead_names), len(keys)), np.nan),
-                           coords=[index.index, ahead_names, keys], dims=[index_dim, ahead_dim, type_dim])
+        res_coords = [index.index, ahead_names, keys]
+        res_dims = [index_dim, ahead_dim, type_dim]
+        res_fill_shape = (len(index.index), len(ahead_names), len(keys))
+        if (ens_coords is not None) and (ens_dims is not None):
+            ens_coords = to_list(ens_coords)
+            ens_dims = to_list(ens_dims)
+            res_coords = to_list(res_coords[0]) + ens_coords + to_list(res_coords[1:])
+            res_dims = to_list(res_dims[0]) + to_list(ens_dims) + to_list(res_dims[1:])
+            res_fill_shape = [len(i) for i in res_coords]
+        res = xr.DataArray(np.full(res_fill_shape, np.nan),
+                           coords=res_coords, dims=res_dims)
         for k, v in kwargs.items():
             intersection = set(res.index.values) & set(v.indexes[time_dimension].values)
             match_index = np.array(list(intersection))
-            res.loc[match_index, :, k] = v.loc[match_index]
+            res.loc[match_index, ..., k] = v.loc[match_index]
         return res

     def _get_internal_data(self, station: str, path: str) -> Union[xr.DataArray, None]:
...
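
A sketch of the extended shape logic in create_forecast_arrays, reproduced outside MLAir: when ens_dims and ens_coords are given, the ensemble dimensions are spliced in between the time index and the (ahead, type) dimensions, and the NaN template takes its shape from the coordinate lengths. The to_list helper below is a simplified stand-in for the mlair.helpers version assumed by the diff; all other names are illustrative.

import numpy as np
import pandas as pd
import xarray as xr


def to_list(obj):
    """Simplified stand-in for mlair.helpers.to_list: wrap non-list objects in a list."""
    return obj if isinstance(obj, list) else [obj]


index = pd.date_range("2020-01-01", periods=5, freq="D")
ahead_names, keys = [1, 2, 3], ["ens"]

res_coords = [index, ahead_names, keys]
res_dims = ["index", "ahead", "type"]
res_fill_shape = (len(index), len(ahead_names), len(keys))

ens_dims = ["realization", "moment"]
ens_coords = [range(4), ["ens_dist_mean", "ens_dist_stddev"]]
if (ens_coords is not None) and (ens_dims is not None):
    # splice the ensemble dimensions in directly after the time index
    res_coords = to_list(res_coords[0]) + ens_coords + res_coords[1:]
    res_dims = to_list(res_dims[0]) + ens_dims + res_dims[1:]
    res_fill_shape = [len(c) for c in res_coords]

res = xr.DataArray(np.full(res_fill_shape, np.nan), coords=res_coords, dims=res_dims)
print(res.dims)  # ('index', 'realization', 'moment', 'ahead', 'type')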