diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 59e98c5bd47f7fb1e24fc6f9f7a9dc05359b587e..2c7d8fdedb2720b1276b77a5e817723094ecf03d 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -695,9 +695,10 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): """ - def __init__(self, data: pd.DataFrame, plot_folder=".", model_setup="CNN"): + def __init__(self, data: pd.DataFrame, plot_folder=".", model_setup="NN"): """Initialise.""" super().__init__(plot_folder, f"skill_score_competitive_{model_setup}") + self._model_setup = model_setup self._labels = None self._data = self._prepare_data(data) self._plot() @@ -754,7 +755,7 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): def _create_pseudo_order(self): """Provide first predefined elements and append all remaining.""" - first_elements = ["cnn-persi", "ols-persi", "cnn-ols"] + first_elements = [f"{self._model_setup}-persi", "ols-persi", f"{self._model_setup}-ols"] uniq, index = np.unique(first_elements + self._data.comparison.unique().tolist(), return_index=True) return uniq[index.argsort()] @@ -1199,6 +1200,7 @@ class PlotAvailability(AbstractPlotClass): def _plot(self, plt_dict): colors = self.get_dataset_colors() + _used_colors = [] pos = 0 height = 0.8 # should be <= 1 yticklabels = [] @@ -1210,13 +1212,15 @@ class PlotAvailability(AbstractPlotClass): plt_data = d.get(subset) if plt_data is None: continue + elif color not in _used_colors: # this is required for a proper legend creation + _used_colors.append(color) ax.broken_barh(plt_data, (pos, height), color=color, edgecolor="white", linewidth=self.linewidth) yticklabels.append(station) ax.set_ylim([height, number_of_stations + 1]) ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2) ax.set_yticklabels(yticklabels) - handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items()] + handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items() if c in _used_colors] lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles)) return lgd