def mann_whitney_u_test(data: pd.DataFrame, reference_col_name: str, **kwargs):
    """
    Calculate the Mann-Whitney U test of every column against a reference column.

    Uses pandas' .apply() on scipy.stats.mannwhitneyu(x, y, ...), where x is the column handed over by .apply() and
    y is the reference column.

    :param data: table whose columns are compared against the reference column
    :param reference_col_name: name of the column which is used for comparison (y)
    :param kwargs: forwarded to `data.apply`; keywords that .apply() does not consume itself (e.g. `alternative`) are
        passed on to scipy.stats.mannwhitneyu
    :return: result of .apply() with the positional index labels 0 and 1 renamed to "statistics" and "p-value"
    :raises KeyError: if `reference_col_name` is not a column of `data`
    """
    res = data.apply(stats.mannwhitneyu, y=data[reference_col_name], **kwargs)
    res = res.rename(index={0: "statistics", 1: "p-value"})
    return res


def represent_p_values_as_asteriks(p_values: pd.Series, threshold_representation: OrderedDict = None):
    """
    Represent p-values as significance labels (asterisks) based on their value.

    The thresholds are applied from the largest to the smallest one, so each p-value ends up with the label of the
    smallest threshold it falls below. P-values that match no threshold (including NaN) stay NaN.

    :param p_values: p-values to translate
    :param threshold_representation: mapping from threshold to label with keys in strictly decreasing order. Defaults
        to 1 -> "ns", 0.05 -> "*", 0.01 -> "**", 0.001 -> "***"
    :return: series of labels (object dtype) sharing the index of `p_values`
    :raises ValueError: if the keys of `threshold_representation` are not in strictly decreasing order
    """
    if threshold_representation is None:
        threshold_representation = OrderedDict([(1, "ns"), (0.05, "*"), (0.01, "**"), (0.001, "***")])

    thresholds = list(threshold_representation.keys())
    if not all(x > y for x, y in zip(thresholds, thresholds[1:])):
        raise ValueError(
            f"`threshold_representation' keys mus be in strictly "
            f"decreasing order but is: {threshold_representation.keys()}")

    # explicit object dtype: the series only ever holds string labels (or NaN where no threshold matched)
    asteriks = pd.Series(index=p_values.index, dtype=object)
    for k, v in threshold_representation.items():
        # A p-value cannot exceed 1, so a threshold >= 1 is a catch-all and must include p == 1 itself (e.g. the
        # reference sample tested against itself). All other thresholds keep the strict comparison.
        mask = p_values <= k if k >= 1 else p_values < k
        asteriks[mask] = v
    return asteriks
@property
def get_asteriks_from_mann_whitney_u_result(self):
    """
    Significance labels of a two-sided Mann-Whitney U test against the reference model.

    Every column of `self._data_table` is tested against the column named `self.model_name`; the last row of the
    test result is handed to `represent_p_values_as_asteriks`.

    :return: labels as returned by `represent_p_values_as_asteriks`
    """
    u_test_result = mann_whitney_u_test(
        data=self._data_table,
        reference_col_name=self.model_name,
        axis=0,
        alternative="two-sided",
    )
    last_row = u_test_result.iloc[-1]
    return represent_p_values_as_asteriks(last_row)
def set_sigificance_bars_vertical(self, asteriks, ax, data_table):
    """
    Draw significance brackets above the boxes of a vertically oriented plot.

    For each entry of `asteriks` except the reference model (`self.model_name`), a bracket is drawn from the
    reference position to the entry's position and labelled with the entry's value. Positions are the enumeration
    indices of `asteriks`; this assumes `asteriks` and the columns of `data_table` share the same order
    (TODO confirm against the caller).

    :param asteriks: labels indexed by column name, must contain `self.model_name`
    :param ax: axes to draw on
    :param data_table: data shown in the plot, used to place the brackets beyond the data
    :return: the axes passed in
    """
    x1 = list(asteriks.index).index(self.model_name)
    h = .01 * data_table.max().max()  # bracket tick height, independent of the loop
    y_prev = 0.
    for x2, v in enumerate(asteriks):
        if x2 == x1:
            continue
        y = data_table[[self.model_name, data_table.columns[x2]]].max().max()
        # never start below the previous bracket, so that brackets stack instead of overlapping
        y = max(y, y_prev) * 1.025
        ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], c="k")
        ax.text((x1 + x2) * .5, y + h, v, ha="center", va="bottom", color="k")
        y_prev = y
    return ax


def set_sigificance_bars_horizontal(self, asteriks, ax, data_table):
    """
    Draw significance brackets to the right of the boxes of a horizontally oriented plot.

    Counterpart of `set_sigificance_bars_vertical` with swapped axes: for each entry of `asteriks` except the
    reference model (`self.model_name`), a bracket is drawn from the reference position to the entry's position and
    labelled with the entry's value. Positions are the enumeration indices of `asteriks`; this assumes `asteriks`
    and the columns of `data_table` share the same order (TODO confirm against the caller).

    :param asteriks: labels indexed by column name, must contain `self.model_name`
    :param ax: axes to draw on
    :param data_table: data shown in the plot, used to place the brackets beyond the data
    :return: the axes passed in
    """
    y1 = list(asteriks.index).index(self.model_name)
    h = .01 * data_table.max().max()  # bracket tick width, independent of the loop
    x_prev = 0.
    for y2, v in enumerate(asteriks):
        if y2 == y1:
            continue
        x = data_table[[self.model_name, data_table.columns[y2]]].max().max()
        x = max(x, x_prev) * 1.025
        ax.plot([x, x + h, x + h, x], [y1, y1, y2, y2], c="k")
        ax.text(x + h, (y1 + y2) * .5, v, ha="left", va="center", color="k", rotation=-90)
        # remember the bracket position; without this every bracket is placed relative to 0 and brackets overlap
        x_prev = x
    return ax