diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py index d4c25bb8891b88e114ea9244cf308bd810910daf..6a54212ab6969d7145d54988a141c8cd89fdc837 100644 --- a/mlair/helpers/statistics.py +++ b/mlair/helpers/statistics.py @@ -212,6 +212,11 @@ def mean_squared_error(a, b, dim=None): return np.square(a - b).mean(dim) +def mean_squared_error_nan(a: xr.DataArray, b: xr.DataArray, dim=None) -> xr.DataArray: + """Calculate mean squared error.""" + return xr.ufuncs.square(a - b).mean(dim, skipna=True) + + def mean_absolute_error(a, b, dim=None): """Calculate mean absolute error.""" return np.abs(a - b).mean(dim) @@ -231,7 +236,7 @@ def skill_score_based_on_mse(data: xr.DataArray, obs_name: str, pred_name: str, obs = data.sel({competitor_dim: obs_name}) pred = data.sel({competitor_dim: pred_name}) ref = data.sel({competitor_dim: ref_name}) - ss = 1 - mean_squared_error(obs, pred, dim=aggregation_dim) / mean_squared_error(obs, ref, dim=aggregation_dim) + ss = 1 - mean_squared_error_nan(obs, pred, dim=aggregation_dim) / mean_squared_error_nan(obs, ref, dim=aggregation_dim) return ss diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 9e3fbdb0b154b6463dd1d03d9ee2b241dd48ca9d..b5c1fcb14a15f0f5b66ea5ad814360c91909c459 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -1006,17 +1006,20 @@ class PostProcessing(RunEnvironment): wind_sectors = self.data_store.get("wind_sectors", "general") h_sector_skill_scores = [] for sec in wind_sectors: + + # prepare ds to drop nans only for time steps which are currently not eq. sec: + # https://stackoverflow.com/questions/52553925/python-xarray-remove-coordinates-with-all-missing-variables + hds = ds.where(self.upstream_wind_sector.squeeze() == sec) + # |__________________ stack all dims without time ________________| |_ drop along index_ | |_unstack_| + hds = hds.stack(z=(self.iter_dim, self.ahead_dim, self.model_type_dim)).dropna(self.index_dim, + how="all").unstack() h_sector_skill_scores.append( - # statistics.SkillScores(None).general_skill_score(ds.where(self.upstream_wind_sector.squeeze() == sec), - # forecast_name=self.model_display_name, - # reference_name=ref_name, - # observation_name=self.observation_indicator) - statistics.skill_score_based_on_mse( - ds.where(self.upstream_wind_sector.squeeze() == sec).dropna(dim=self.index_dim), - obs_name=self.observation_indicator, pred_name=self.model_display_name, - ref_name=ref_name).assign_coords({"sector": sec} - ) - ) + statistics.skill_score_based_on_mse(hds, + obs_name=self.observation_indicator, + pred_name=self.model_display_name, + ref_name=ref_name).assign_coords({"sector": sec} + ) + ) sector_skill_scores = xr.concat(h_sector_skill_scores, dim="sector") sector_skill_scores = sector_skill_scores.assign_attrs({f"reference_model": ref_name}) return sector_skill_scores