diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index a4bd68fcfb66692cb6eb6ba7f1801ae65068e25b..3f11f52c866483b592b6ffcc53becd1b6b2caf4e 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -32,7 +32,7 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING) class PlotContingency(AbstractPlotClass): def __init__(self, station_names, file_path, comp_path, file_name, plot_folder: str = ".", model_name: str = "nn", - obs_name: str = "obs", comp_names: str = "IntelliO3", + model_plot_name: str = "nn", obs_name: str = "obs", comp_names: str = "IntelliO3", plot_names=["contingency_threat_score", "contingency_hit_rate", "contingency_false_alarm_rate", "contingency_bias", "contingency_all_scores", "contingency_table"]): @@ -43,6 +43,7 @@ class PlotContingency(AbstractPlotClass): self._file_name = file_name self._obs_name = obs_name self._model_name = model_name + self._model_plot_name = model_plot_name self._comp_names = to_list(comp_names) self._all_names = [self._model_name] self._all_names.extend(self._comp_names) @@ -88,7 +89,11 @@ class PlotContingency(AbstractPlotClass): plt.plot(range(self._min_threshold, self._max_threshold), data.loc[dict(type="nn", scores=score_name)], label=score_name) else: for type in data.type.values.tolist(): - plt.plot(range(self._min_threshold, self._max_threshold), data.loc[dict(type=type, scores=score)], label=type) + if type in "nn": + plt.plot(range(self._min_threshold, self._max_threshold), data.loc[dict(type=type, scores=score)], + label=self._model_plot_name) + else: + plt.plot(range(self._min_threshold, self._max_threshold), data.loc[dict(type=type, scores=score)], label=type) plt.title(self._plot_names[self._plot_counter]) plt.legend() self.plot_name = self._plot_names[self._plot_counter]