diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py index 3c8bb04e8891837a5271fd515489f077677e43cc..608050f3def2f7bbc1dc13ffecdaeaf0a39c98c8 100644 --- a/mlair/plotting/postprocessing_plotting.py +++ b/mlair/plotting/postprocessing_plotting.py @@ -1155,14 +1155,14 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover width=width, orient=orientation) if orientation == "v": if apply_u_test: - ax = self.set_sigificance_bars_vertical(asteriks, ax, data_table) + ax = self.set_significance_bars(asteriks, ax, data_table, orientation) ylims = list(ax.get_ylim()) ax.set_ylim([ylims[0], ylims[1]*1.025]) ax.set_ylabel(f"{self.error_measure} (in {self.error_unit})") ax.set_xticklabels(ax.get_xticklabels(), rotation=45) elif orientation == "h": if apply_u_test: - ax = self.set_sigificance_bars_horizontal(asteriks, ax, data_table) + ax = self.set_significance_bars(asteriks, ax, data_table, orientation) ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})") xlims = list(ax.get_xlim()) ax.set_xlim([xlims[0], xlims[1] * 1.015]) @@ -1180,35 +1180,25 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover self._save() plt.close("all") - def set_sigificance_bars_vertical(self, asteriks, ax, data_table): - x1 = list(asteriks.index).index(self.model_name) - y_prev = 0. - for i, v in enumerate(asteriks): + def set_significance_bars(self, asteriks, ax, data_table, orientation): + p1 = list(asteriks.index).index(self.model_name) + q_prev = 0. + factor = 0.025 + for i, ast in enumerate(asteriks): if not i == list(asteriks.index).index(self.model_name): - x2 = i - y = data_table[[self.model_name, data_table.columns[i]]].max().max() - y = max(y, y_prev) * 1.025 - if abs(y-y_prev) < y * 0.025: - y = y * 1.025 - h = .01 * data_table.max().max() - ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], c="k") - ax.text((x1 + x2) * .5, y + h, v, ha="center", va="bottom", color="k") - y_prev = y - return ax - - def set_sigificance_bars_horizontal(self, asteriks, ax, data_table): - y1 = list(asteriks.index).index(self.model_name) - x_prev = 0. - for i, v in enumerate(asteriks): - if not i == list(asteriks.index).index(self.model_name): - y2 = i - x = data_table[[self.model_name, data_table.columns[i]]].max().max() - x = max(x, x_prev) * 1.025 - if abs(x-x_prev) < x * 0.025: - x = x * 1.025 - h = .01 * data_table.max().max() - ax.plot([x, x+h, x+h, x], [y1, y1, y2, y2], c="k") - ax.text(x + h, (y1 + y2) * .5, v, ha="left", va="center", color="k", rotation=-90) + p2 = i + q = data_table[[self.model_name, data_table.columns[i]]].max().max() + q = max(q, q_prev) * (1 + factor) + if abs(q - q_prev) < q * factor: + q = q * (1 + factor) + h = 0.01 * data_table.max().max() + if orientation == "h": + ax.plot([q, q + h, q + h, q], [p1, p1, p2, p2], c="k") + ax.text(q + h, (p1 + p2) * 0.5, ast, ha="left", va="center", color="k", rotation=-90) + elif orientation == "v": + ax.plot([p1, p1, p2, p2], [q, q + h, q + h, q], c="k") + ax.text((p1 + p2) * 0.5, q + h, ast, ha="center", va="bottom", color="k") + q_prev = q return ax