diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index e6d6de152e42d44f271ba986b6645d2cd36b68d0..748476b814c6c54812df27274f54615bbf08d269 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -641,20 +641,32 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover plot_name = self.plot_name for branch in self._data["branch"].unique(): self._set_title(model_name, branch) - self._plot(branch=branch) self.plot_name = f"{plot_name}_{branch}" - self._save() + try: + self._plot(branch=branch) + self._save() + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") if len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0: self.plot_name += '_separated' - self._plot(branch=branch, separate_vars=separate_vars) - self._save(bbox_inches='tight') + try: + self._plot(branch=branch, separate_vars=separate_vars) + self._save(bbox_inches='tight') + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") else: - self._plot() - self._save() + try: + self._plot() + self._save() + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") if len(set(separate_vars).intersection(self._data[self._x_name].unique())) > 0: self.plot_name += '_separated' - self._plot(separate_vars=separate_vars) - self._save(bbox_inches='tight') + try: + self._plot(separate_vars=separate_vars) + self._save(bbox_inches='tight') + except ValueError as e: + logging.info(f"Did not plot {self.plot_name} because of {e}") @staticmethod def _set_bootstrap_type(boot_type): @@ -696,11 +708,26 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover number_tags = self._get_number_tag(data.coords[self._x_name].values, split_by='_') new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', keep=1, as_unique=True) - values = data.values.reshape((*data.shape[:3], len(number_tags), len(new_boot_coords))) - data = xr.DataArray(values, coords={station_dim: data.coords[station_dim], self._x_name: new_boot_coords, - "branch": number_tags, self._ahead_dim: data.coords[self._ahead_dim], - self._boot_dim: data.coords[self._boot_dim]}, - dims=[station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name]) + try: + values = data.values.reshape((*data.shape[:3], len(number_tags), len(new_boot_coords))) + data = xr.DataArray(values, coords={station_dim: data.coords[station_dim], self._x_name: new_boot_coords, + "branch": number_tags, self._ahead_dim: data.coords[self._ahead_dim], + self._boot_dim: data.coords[self._boot_dim]}, + dims=[station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name]) + except ValueError: + data_coll = [] + for nr in number_tags: + filtered_coords = list(filter(lambda x: nr in x.split("_")[0], data.coords[self._x_name].values)) + new_boot_coords = self._return_vars_without_number_tag(filtered_coords, split_by='_', keep=1, + as_unique=True) + sel_data = data.sel({self._x_name: filtered_coords}) + values = sel_data.values.reshape((*data.shape[:3], 1, len(new_boot_coords))) + sel_data = xr.DataArray(values, coords={station_dim: data.coords[station_dim], self._x_name: new_boot_coords, + "branch": [nr], self._ahead_dim: data.coords[self._ahead_dim], + self._boot_dim: data.coords[self._boot_dim]}, + dims=[station_dim, self._ahead_dim, self._boot_dim, "branch", self._x_name]) + data_coll.append(sel_data) + data = xr.concat(data_coll, "branch") else: try: new_boot_coords = self._return_vars_without_number_tag(data.coords[self._x_name].values, split_by='_', @@ -713,7 +740,7 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover if station_dim not in data.dims: data = data.expand_dims(station_dim) self._number_of_bootstraps = np.unique(data.coords[self._boot_dim].values).shape[0] - return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()) + return data.to_dataframe("data").reset_index(level=np.arange(len(data.dims)).tolist()).dropna() @staticmethod def _get_target_sampling(sampling, pos): @@ -765,9 +792,10 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover def _plot_selected_variables(self, separate_vars: List, branch=None): data = self._data if branch is None else self._data[self._data["branch"] == str(branch)] - self.raise_error_if_separate_vars_do_not_exist(data, separate_vars, self._x_name) + self.raise_error_if_vars_do_not_exist(data, separate_vars, self._x_name, name="separate_vars") all_variables = self._get_unique_values_from_column_of_df(data, self._x_name) remaining_vars = helpers.remove_items(all_variables, separate_vars) + self.raise_error_if_vars_do_not_exist(data, remaining_vars, self._x_name, name="remaining_vars") data_first = self._select_data(df=data, variables=separate_vars, column_name=self._x_name) data_second = self._select_data(df=data, variables=remaining_vars, column_name=self._x_name) @@ -843,9 +871,13 @@ class PlotFeatureImportanceSkillScore(AbstractPlotClass): # pragma: no cover selected_data = pd.concat([selected_data, tmp_var], axis=0) return selected_data - def raise_error_if_separate_vars_do_not_exist(self, data, separate_vars, column_name): - if not self._variables_exist_in_df(df=data, variables=separate_vars, column_name=column_name): - raise ValueError(f"At least one entry of `separate_vars' does not exist in `self.data' ") + def raise_error_if_vars_do_not_exist(self, data, vars, column_name, name="separate_vars"): + if len(vars) == 0: + msg = f"No variables are given for `{name}' to check in `self.data' " + raise ValueError(msg) + if not self._variables_exist_in_df(df=data, variables=vars, column_name=column_name): + msg = f"At least one entry of `{name}' does not exist in `self.data' " + raise ValueError(msg) @staticmethod def _get_unique_values_from_column_of_df(df: pd.DataFrame, column_name: str) -> List: