diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 10cde645eed44e440f5687d6b5498b5add3ea98d..00c925031e8e8bc804979c75052355837f5cb614 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -698,12 +698,14 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): """ - def __init__(self, data: pd.DataFrame, plot_folder=".", model_setup="NN", sampling="daily"): + def __init__(self, data: pd.DataFrame, plot_folder=".", model_setup="NN", sampling="daily", + model_name_for_plots=None): """Initialise.""" super().__init__(plot_folder, f"skill_score_competitive_{model_setup}") self._model_setup = model_setup self._sampling = self._get_sampling(sampling) self._labels = None + self._model_name_for_plots = model_name_for_plots self._data = self._prepare_data(data) default_plot_name = self.plot_name # draw full detail plot @@ -745,6 +747,8 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): size = max([len(np.unique(self._data.comparison)), 6]) fig, ax = plt.subplots(figsize=(size, size * 0.8)) data = self._filter_comparisons(self._data) if single_model_comparison is True else self._data + if self._model_name_for_plots is not None: + data['comparison'] = [i.replace('nn-', f'{self._model_name_for_plots}-') for i in data['comparison']] order = self._create_pseudo_order(data) sns.boxplot(x="comparison", y="data", hue="ahead", data=data, whis=1., ax=ax, palette="Blues_d", showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, @@ -761,6 +765,8 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): """Plot skill scores of the comparisons, but vertically aligned.""" fig, ax = plt.subplots() data = self._filter_comparisons(self._data) if single_model_comparison is True else self._data + if self._model_name_for_plots is not None: + data['comparison'] = [i.replace('nn-', f'{self._model_name_for_plots}-') for i in data['comparison']] order = self._create_pseudo_order(data) sns.boxplot(y="comparison", x="data", hue="ahead", data=data, whis=1., ax=ax, palette="Blues_d", showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, @@ -780,7 +786,8 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): return uniq[index.argsort()] def _filter_comparisons(self, data): - filtered_headers = list(filter(lambda x: "nn-" in x, data.comparison.unique())) + # filtered_headers = list(filter(lambda x: "nn-" in x, data.comparison.unique())) + filtered_headers = list(filter(lambda x: f"{self._model_name_for_plots}-" in x, data.comparison.unique())) return data[data.comparison.isin(filtered_headers)] def _lim(self) -> Tuple[float, float]: @@ -906,14 +913,19 @@ class PlotBootstrapSkillScore(AbstractPlotClass): data_second = self._select_data(df=data, variables=remaining_vars, column_name='boot_var') order_first = self.set_order_for_x_axis(separate_vars) - order_second = self.set_order_for_x_axis(remaining_vars) order_second, center_names_second = self.set_order_for_x_axis(remaining_vars, return_center_names=True) number_of_vars_second = len(order_second) group_size = int(number_of_vars_second / len(center_names_second)) + if len(self._individual_vars) > 20: + figsize = (len(self._individual_vars) / 2, 10) + else: + figsize = (15, 10) + + fig, ax = plt.subplots(nrows=1, ncols=2, - figsize=(len(self._individual_vars) / 2, 10), + figsize=figsize, gridspec_kw={'width_ratios': [len(separate_vars), len(remaining_vars) ] @@ -1004,7 +1016,7 @@ class PlotBootstrapSkillScore(AbstractPlotClass): if number_of_vars > 20: fig, ax = plt.subplots(figsize=(number_of_vars/2, 10)) else: - fig, ax = plt.subplots() + fig, ax = plt.subplots(figsize=(15, 10)) sns.boxplot(x=self._x_name, y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."}, order=order) diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py index 0d80fe7431347c48134ed1a300b93cf8d3e33195..d5033cf70a4a7d13c253843ae627360c0b2596d4 100644 --- a/mlair/run_modules/post_processing.py +++ b/mlair/run_modules/post_processing.py @@ -84,6 +84,7 @@ class PostProcessing(RunEnvironment): self.competitor_path = self.data_store.get("competitor_path") self.competitors = to_list(self.data_store.get_default("competitors", default=[])) self.forecast_indicator = "nn" + self.model_name_for_plots = self.data_store.get_default("model_name_for_plots", default=None) self._run() def _run(self): @@ -363,7 +364,8 @@ class PostProcessing(RunEnvironment): try: if "PlotCompetitiveSkillScore" in plot_list: PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, - model_setup=self.forecast_indicator, sampling=self._sampling) + model_setup=self.forecast_indicator, sampling=self._sampling, + model_name_for_plots=self.model_name_for_plots) except Exception as e: logging.error(f"Could not create plot PlotCompetitiveSkillScore due to the following error: {e}")