diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py index 0251f3eab1101ce2c23433ac5f63d8a87dd71a9a..546a463650ccca4c6f7e2b63b3afb01db9d90a40 100644 --- a/mlair/helpers/statistics.py +++ b/mlair/helpers/statistics.py @@ -209,7 +209,7 @@ class SkillScores: def skill_scores(self, window_lead_time: int) -> pd.DataFrame: """ - Calculate skill scores for all combinations of CNN, persistence and OLS. + Calculate skill scores for all combinations of model names. :param window_lead_time: length of forecast steps @@ -228,7 +228,7 @@ class SkillScores: return skill_score def climatological_skill_scores(self, external_data: Data, window_lead_time: int, - forecast_name: str = "cnn") -> xr.DataArray: + forecast_name: str) -> xr.DataArray: """ Calculate climatological skill scores according to Murphy (1988). @@ -273,8 +273,8 @@ class SkillScores: kwargs = {"external_data": external_data} if external_data is not None else {} return self.__getattribute__(f"skill_score_mu_case_{mu_type}")(data, observation_name, forecast_name, **kwargs) - @staticmethod - def general_skill_score(data: Data, observation_name: str, forecast_name: str, reference_name: str) -> np.ndarray: + def general_skill_score(self, data: Data, forecast_name: str, reference_name: str, + observation_name: str = None) -> np.ndarray: r""" Calculate general skill score based on mean squared error. @@ -285,6 +285,8 @@ class SkillScores: :return: skill score of forecast """ + if observation_name is None: + observation_name = self.observation_name data = data.dropna("index") observation = data.sel(type=observation_name) forecast = data.sel(type=forecast_name)