From c9ec2f00f3a21b57142b5f16ac39a11454afbcdc Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Tue, 25 Jan 2022 18:56:58 +0100
Subject: [PATCH] frst implementation of mw-utest in plots

---
 mlair/helpers/statistics.py               | 42 +++++++++++
 mlair/plotting/postprocessing_plotting.py | 88 ++++++++++++++++++++---
 2 files changed, 120 insertions(+), 10 deletions(-)

diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index 19a4893d..d1388a35 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -10,6 +10,7 @@ import xarray as xr
 import pandas as pd
 from typing import Union, Tuple, Dict, List
 import itertools
+from collections import OrderedDict
 
 Data = Union[xr.DataArray, pd.DataFrame]
 
@@ -219,6 +220,47 @@ def calculate_error_metrics(a, b, dim):
     return {"mse": mse, "rmse": rmse, "mae": mae, "n": n}
 
 
+def mann_whitney_u_test(data: pd.DataFrame, reference_col_name: str, **kwargs):
+    """
+    Calculate Mann-Whitney u-test. Uses pandas' .apply() on scipy.stats.mannwhitneyu(x, y, ...).
+    :param data:
+    :type data:
+    :param reference_col_name: Name of column which is used for comparison (y)
+    :type reference_col_name:
+    :param kwargs:
+    :type kwargs:
+    :return:
+    :rtype:
+    """
+    res = data.apply(stats.mannwhitneyu, y=data[reference_col_name], **kwargs)
+    res = res.rename(index={0: "statistics", 1: "p-value"})
+    return res
+
+
+def represent_p_values_as_asteriks(p_values: pd.Series, threshold_representation: OrderedDict = None):
+    """
+    Represent p-values as asteriks based on its value.
+    :param p_values:
+    :type p_values:
+    :param threshold_representation:
+    :type threshold_representation:
+    :return:
+    :rtype:
+    """
+    if threshold_representation is None:
+        threshold_representation = OrderedDict([(1, "ns"), (0.05, "*"), (0.01, "**"), (0.001, "***")])
+
+    if not all(x > y for x, y in zip(list(threshold_representation.keys()), list(threshold_representation.keys())[1:])):
+        raise ValueError(
+            f"`threshold_representation' keys mus be in strictly "
+            f"decreasing order but is: {threshold_representation.keys()}")
+
+    asteriks = pd.Series().reindex_like(p_values)
+    for k, v in threshold_representation.items():
+        asteriks[p_values < k] = v
+    return asteriks
+
+
 class SkillScores:
     r"""
     Calculate different kinds of skill scores.
diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 748476b8..4799687f 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -16,11 +16,13 @@ import seaborn as sns
 import xarray as xr
 from matplotlib.backends.backend_pdf import PdfPages
 from matplotlib.offsetbox import AnchoredText
+from scipy.stats import mannwhitneyu
 
 from mlair import helpers
 from mlair.data_handler.iterator import DataCollection
 from mlair.helpers import TimeTrackingWrapper
 from mlair.plotting.abstract_plot_class import AbstractPlotClass
+from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks
 
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
@@ -1092,24 +1094,51 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
         self.dim_name_boots = dim_name_boots
         self.error_unit = error_unit
         self.block_length = block_length
+        self.model_name = model_name
         data = self.rename_model_indicator(data, model_name, model_indicator)
         self.prepare_data(data)
-        self._plot(orientation="v")
 
-        self.plot_name = default_name + "_horizontal"
-        self._plot(orientation="h")
+        variants = [("_vertical", "v", False), ("_vertical_u_test", "v", True),
+                    ("_horizontal", "h", False), ("_horizontal_u_test", "h", True)]
+
+        for name_tag, orientation, utest in variants:
+            self.plot_name = default_name + name_tag
+            self._plot(orientation=orientation, apply_u_test=utest)
+
+        # self._plot(orientation="v", apply_u_test=False)
+        #
+        # self.plot_name = default_name + "_u_test"
+        # self._plot(orientation="v", apply_u_test=True)
+        #
+        # self.plot_name = default_name + "_horizontal"
+        # self._plot(orientation="h", apply_u_test=False)
+        #
+        # self.plot_name = default_name + "_horizontal_u_test"
+        # self._plot(orientation="h", apply_u_test=True)
 
         self._apply_root()
 
-        self.plot_name = default_name + "_sqrt"
-        self._plot(orientation="v")
+        variants = [(tag+"_sqrt", ori, ut) for tag, ori, ut in variants]
+
+        for name_tag, orientation, utest in variants:
+            self.plot_name = default_name + name_tag
+            self._plot(orientation=orientation, apply_u_test=utest)
 
-        self.plot_name = default_name + "_horizontal_sqrt"
-        self._plot(orientation="h")
+        # self.plot_name = default_name + "_sqrt"
+        # self._plot(orientation="v")
+        #
+        # self.plot_name = default_name + "_horizontal_sqrt"
+        # self._plot(orientation="h")
 
         self._data_table = None
         self._n_boots = None
 
+    @property
+    def get_asteriks_from_mann_whitney_u_result(self):
+        return represent_p_values_as_asteriks(mann_whitney_u_test(data=self._data_table,
+                                                                  reference_col_name=self.model_name,
+                                                                  axis=0, alternative="two-sided").iloc[-1])
+
     def rename_model_indicator(self, data, model_name, model_indicator):
         data.coords[self.model_type_dim] = [{model_indicator: model_name}.get(n, n)
                                             for n in data.coords[self.model_type_dim].values]
@@ -1125,10 +1154,11 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
         self.error_measure = f"root {self.error_measure}"
         self.error_unit = self.error_unit.replace("$^2$", "")
 
-    def _plot(self, orientation: str = "v"):
+    def _plot(self, orientation: str = "v", apply_u_test: bool = False):
         data_table = self._data_table
         n_boots = self._n_boots
         size = len(np.unique(data_table.columns))
+        asteriks = self.get_asteriks_from_mann_whitney_u_result
         if orientation == "v":
             figsize, width = (size, 5), 0.4
         elif orientation == "h":
@@ -1142,15 +1172,26 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
                     boxprops={'facecolor': 'none', 'edgecolor': 'k'},
                     width=width, orient=orientation)
         if orientation == "v":
+            if apply_u_test:
+                ax = self.set_sigificance_bars_vertical(asteriks, ax, data_table)
+            ylims = list(ax.get_ylim())
+            ax.set_ylim([ylims[0], ylims[1]*1.025])
             ax.set_ylabel(f"{self.error_measure} (in {self.error_unit})")
             ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
         elif orientation == "h":
+            if apply_u_test:
+                ax = self.set_sigificance_bars_horizontal(asteriks, ax, data_table)
             ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})")
+            xlims = list(ax.get_xlim())
+            ax.set_xlim([xlims[0], xlims[1] * 1.015])
+
         else:
             raise ValueError(f"orientation must be `v' or `h' but is: {orientation}")
         text = f"n={n_boots}" if self.block_length is None else f"{self.block_length}, n={n_boots}"
-        loc = "upper right" if orientation == "h" else "upper left"
-        text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5)
+        # loc = "upper right" if orientation == "h" else "upper left"
+        loc = "lower left"
+        text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5, bbox_to_anchor=(0., 1.0),
+                       bbox_transform=ax.transAxes)
         plt.setp(text_box.patch, edgecolor='k', facecolor='w')
         ax.add_artist(text_box)
         plt.setp(ax.lines, color='k')
@@ -1158,6 +1199,33 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
         self._save()
         plt.close("all")
 
+    def set_sigificance_bars_vertical(self, asteriks, ax, data_table):
+        x1 = list(asteriks.index).index(self.model_name)
+        y_prev = 0.
+        for i, v in enumerate(asteriks):
+            if not i == list(asteriks.index).index(self.model_name):
+                x2 = i
+                y = data_table[[self.model_name, data_table.columns[i]]].max().max()
+                y = max(y, y_prev) * 1.025
+                h = .01 * data_table.max().max()
+                ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], c="k")
+                ax.text((x1 + x2) * .5, y + h, v, ha="center", va="bottom", color="k")
+                y_prev = y
+        return ax
+
+    def set_sigificance_bars_horizontal(self, asteriks, ax, data_table):
+        y1 = list(asteriks.index).index(self.model_name)
+        x_prev = 0.
+        for i, v in enumerate(asteriks):
+            if not i == list(asteriks.index).index(self.model_name):
+                y2 = i
+                x = data_table[[self.model_name, data_table.columns[i]]].max().max()
+                x = max(x, x_prev) * 1.025
+                h = .01 * data_table.max().max()
+                ax.plot([x, x+h, x+h, x], [y1, y1, y2, y2], c="k")
+                ax.text(x + h, (y1 + y2) * .5, v, ha="left", va="center", color="k", rotation=-90)
+        return ax
+
 
 if __name__ == "__main__":
     stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
-- 
GitLab