diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 748476b814c6c54812df27274f54615bbf08d269..5535114dba1bf95d51ba857daba7ad80c7c1563e 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -640,7 +640,7 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover if "branch" in self._data.columns: plot_name = self.plot_name for branch in self._data["branch"].unique(): - self._set_title(model_name, branch) + self._set_title(model_name, branch, len(self._data["branch"].unique())) self.plot_name = f"{plot_name}_{branch}" try: self._plot(branch=branch) @@ -672,13 +672,17 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover def _set_bootstrap_type(boot_type): return {"singleinput": "single input"}.get(boot_type, boot_type) - def _set_title(self, model_name, branch=None): + def _set_title(self, model_name, branch=None, n_branches=None): title_d = {"single input": "Single Inputs", "branch": "Input Branches", "variable": "Variables"} base_title = f"{model_name}\nImportance of {title_d[self._boot_type]}" additional = [] if branch is not None: - branch_name = self._branches_names[branch] if self._branches_names is not None else branch + try: + assert n_branches == len(self._branches_names) + branch_name = self._branches_names[int(branch)] + except (IndexError, TypeError, ValueError, AssertionError): + branch_name = branch additional.append(branch_name) if self._number_of_bootstraps > 1: additional.append(f"n={self._number_of_bootstraps}") diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index dfcc9edb15b33d8020c094ca81af5332e01782bc..d73f991ec331a8050e91c2ee449f8b28016d48a4 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -358,11 +358,13 @@ class PostProcessing(RunEnvironment): number_of_bootstraps = self.data_store.get("n_boots", "feature_importance") forecast_file = f"forecasts_norm_%s_test.nc" reference_name = "orig" + branch_names = self.data_store.get_default("branch_names", None) bootstraps = Bootstraps(self.test_data[0], number_of_bootstraps, bootstrap_type=bootstrap_type, bootstrap_method=bootstrap_method) number_of_bootstraps = bootstraps.number_of_bootstraps bootstrap_iter = bootstraps.bootstraps() + branch_length = self.get_distinct_branches_from_bootstrap_iter(bootstrap_iter) skill_scores = statistics.SkillScores(None, ahead_dim=self.ahead_dim) score = {} for station in self.test_data: @@ -390,10 +392,11 @@ class PostProcessing(RunEnvironment): boot_scores.append( skill_scores.general_skill_score(data, forecast_name=boot_var, reference_name=reference_name, dim=self.index_dim)) + boot_var_renamed = self.rename_boot_var_with_branch(boot_var, bootstrap_type, branch_names, expected_len=branch_length) tmp = xr.DataArray(np.expand_dims(np.array(boot_scores), axis=-1), coords={self.ahead_dim: range(1, self.window_lead_time + 1), self.uncertainty_estimate_boot_dim: range(number_of_bootstraps), - self.boot_var_dim: [boot_var]}, + self.boot_var_dim: [boot_var_renamed]}, dims=[self.ahead_dim, self.uncertainty_estimate_boot_dim, self.boot_var_dim]) skill.append(tmp) @@ -401,6 +404,31 @@ class PostProcessing(RunEnvironment): score[str(station)] = xr.concat(skill, dim=self.boot_var_dim) return score + @staticmethod + def get_distinct_branches_from_bootstrap_iter(bootstrap_iter): + if isinstance(bootstrap_iter[0], tuple): + return len(set(map(lambda x: x[0], bootstrap_iter))) + else: + return len(bootstrap_iter) + + def rename_boot_var_with_branch(self, boot_var, bootstrap_type, branch_names=None, expected_len=0): + if branch_names is None: + return boot_var + if bootstrap_type == "branch": + try: + assert len(branch_names) > int(boot_var) + assert len(branch_names) == expected_len + return branch_names[int(boot_var)] + except (AssertionError, TypeError): + return boot_var + elif bootstrap_type == "singleinput": + if "_" in boot_var: + branch, other = boot_var.split("_", 1) + branch = self.rename_boot_var_with_branch(branch, "branch", branch_names=branch_names, expected_len=expected_len) + boot_var = "_".join([branch, other]) + return boot_var + return boot_var + def get_orig_prediction(self, path, file_name, prediction_name=None, reference_name=None): if prediction_name is None: prediction_name = self.forecast_indicator @@ -477,6 +505,7 @@ class PostProcessing(RunEnvironment): try: if (self.feature_importance_skill_scores is not None) and ("PlotFeatureImportanceSkillScore" in plot_list): + branch_names = self.data_store.get_default("branch_names", None) for boot_type, boot_data in self.feature_importance_skill_scores.items(): for boot_method, boot_skill_score in boot_data.items(): try: @@ -484,7 +513,7 @@ class PostProcessing(RunEnvironment): boot_skill_score, plot_folder=self.plot_path, model_name=self.model_display_name, sampling=self._sampling, ahead_dim=self.ahead_dim, separate_vars=to_list(self.target_var), bootstrap_type=boot_type, - bootstrap_method=boot_method) + bootstrap_method=boot_method, branch_names=branch_names) except Exception as e: logging.error(f"Could not create plot PlotFeatureImportanceSkillScore ({boot_type}, " f"{boot_method}) due to the following error:\n{sys.exc_info()[0]}\n"