Skip to content
Snippets Groups Projects
Commit c9ec2f00 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

first implementation of mw-utest in plots

parent 6560872e
Branches
Tags
4 merge requests: !430 update recent developments, !413 update release branch, !412 Resolve "release v2.0.0", !380 Resolve "Include Mann-Whitney U rank test"
Pipeline #89682 passed with warnings
...@@ -10,6 +10,7 @@ import xarray as xr ...@@ -10,6 +10,7 @@ import xarray as xr
import pandas as pd import pandas as pd
from typing import Union, Tuple, Dict, List from typing import Union, Tuple, Dict, List
import itertools import itertools
from collections import OrderedDict
Data = Union[xr.DataArray, pd.DataFrame] Data = Union[xr.DataArray, pd.DataFrame]
...@@ -219,6 +220,47 @@ def calculate_error_metrics(a, b, dim): ...@@ -219,6 +220,47 @@ def calculate_error_metrics(a, b, dim):
return {"mse": mse, "rmse": rmse, "mae": mae, "n": n} return {"mse": mse, "rmse": rmse, "mae": mae, "n": n}
def mann_whitney_u_test(data: pd.DataFrame, reference_col_name: str, **kwargs):
    """
    Run a Mann-Whitney U test of every column of `data` against one reference column.

    The test is evaluated with pandas' .apply() on scipy.stats.mannwhitneyu(x, y, ...), where `y` is always the
    reference column (the reference column itself is therefore also tested against itself).

    :param data: table holding the samples to compare
    :type data: pd.DataFrame
    :param reference_col_name: name of the column which is used for comparison (y)
    :type reference_col_name: str
    :param kwargs: handed to `data.apply()`. Keywords which `.apply()` consumes itself (e.g. `axis`) are not
        forwarded, all remaining keywords (e.g. `alternative`) reach `scipy.stats.mannwhitneyu`
    :type kwargs: dict
    :return: result of `.apply()` with the index labels 0 and 1 renamed to "statistics" and "p-value"
    :rtype: pd.DataFrame (presumably, depends on how `.apply()` combines the individual test results)
    """
    reference = data[reference_col_name]
    row_labels = {0: "statistics", 1: "p-value"}
    # kwargs deliberately go through .apply() (and not directly into the test function), see docstring
    return data.apply(stats.mannwhitneyu, y=reference, **kwargs).rename(index=row_labels)
def represent_p_values_as_asteriks(p_values: pd.Series, threshold_representation: OrderedDict = None):
    """
    Represent p-values as asterisks (or any other label) based on their value.

    The thresholds are processed from the largest to the smallest one and each label overwrites the previous one.
    A p-value therefore ends up with the label of the smallest threshold it is strictly below. P-values which are
    not strictly below any threshold (e.g. a p-value of exactly 1 with the default thresholds) and missing p-values
    keep NaN as label.

    :param p_values: p-values to translate
    :type p_values: pd.Series
    :param threshold_representation: mapping of threshold to label, keys must be in strictly decreasing order.
        Default is `{1: "ns", 0.05: "*", 0.01: "**", 0.001: "***"}`
    :type threshold_representation: OrderedDict
    :return: labels with the same index as `p_values` (dtype object)
    :rtype: pd.Series
    :raises ValueError: if the keys of `threshold_representation` are not in strictly decreasing order
    """
    if threshold_representation is None:
        threshold_representation = OrderedDict([(1, "ns"), (0.05, "*"), (0.01, "**"), (0.001, "***")])

    thresholds = list(threshold_representation.keys())
    if not all(upper > lower for upper, lower in zip(thresholds, thresholds[1:])):
        raise ValueError(
            f"`threshold_representation` keys must be in strictly decreasing order but are: {thresholds}")

    # Explicit object dtype: a dtype-less empty Series is deprecated and its default dtype differs between pandas
    # versions. The labels are strings, so object is the dtype the result ends up with anyway.
    asteriks = pd.Series(index=p_values.index, dtype=object)
    for threshold, label in threshold_representation.items():
        asteriks[p_values < threshold] = label
    return asteriks
class SkillScores: class SkillScores:
r""" r"""
Calculate different kinds of skill scores. Calculate different kinds of skill scores.
......
...@@ -16,11 +16,13 @@ import seaborn as sns ...@@ -16,11 +16,13 @@ import seaborn as sns
import xarray as xr import xarray as xr
from matplotlib.backends.backend_pdf import PdfPages from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.offsetbox import AnchoredText from matplotlib.offsetbox import AnchoredText
from scipy.stats import mannwhitneyu
from mlair import helpers from mlair import helpers
from mlair.data_handler.iterator import DataCollection from mlair.data_handler.iterator import DataCollection
from mlair.helpers import TimeTrackingWrapper from mlair.helpers import TimeTrackingWrapper
from mlair.plotting.abstract_plot_class import AbstractPlotClass from mlair.plotting.abstract_plot_class import AbstractPlotClass
from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks
logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.getLogger('matplotlib').setLevel(logging.WARNING)
...@@ -1092,24 +1094,51 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover ...@@ -1092,24 +1094,51 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
self.dim_name_boots = dim_name_boots self.dim_name_boots = dim_name_boots
self.error_unit = error_unit self.error_unit = error_unit
self.block_length = block_length self.block_length = block_length
self.model_name = model_name
data = self.rename_model_indicator(data, model_name, model_indicator) data = self.rename_model_indicator(data, model_name, model_indicator)
self.prepare_data(data) self.prepare_data(data)
self._plot(orientation="v")
self.plot_name = default_name + "_horizontal" variants = [("_vertical", "v", False), ("_vertical_u_test", "v", True),
self._plot(orientation="h") ("_horizontal", "h", False), ("_horizontal_u_test", "h", True)]
for name_tag, orientation, utest in variants:
self.plot_name = default_name + name_tag
self._plot(orientation=orientation, apply_u_test=utest)
# self._plot(orientation="v", apply_u_test=False)
#
# self.plot_name = default_name + "_u_test"
# self._plot(orientation="v", apply_u_test=True)
#
# self.plot_name = default_name + "_horizontal"
# self._plot(orientation="h", apply_u_test=False)
#
# self.plot_name = default_name + "_horizontal_u_test"
# self._plot(orientation="h", apply_u_test=True)
self._apply_root() self._apply_root()
self.plot_name = default_name + "_sqrt" variants = [(tag+"_sqrt", ori, ut) for tag, ori, ut in variants]
self._plot(orientation="v")
for name_tag, orientation, utest in variants:
self.plot_name = default_name + name_tag
self._plot(orientation=orientation, apply_u_test=utest)
self.plot_name = default_name + "_horizontal_sqrt" # self.plot_name = default_name + "_sqrt"
self._plot(orientation="h") # self._plot(orientation="v")
#
# self.plot_name = default_name + "_horizontal_sqrt"
# self._plot(orientation="h")
self._data_table = None self._data_table = None
self._n_boots = None self._n_boots = None
@property
def get_asteriks_from_mann_whitney_u_result(self):
    """
    Significance labels of a two-sided Mann-Whitney U test of all columns of the data table against the model.

    The test uses `self.model_name` as reference column. Only the last row of the test result is translated into
    labels (the row `mann_whitney_u_test` names "p-value").
    """
    u_test_result = mann_whitney_u_test(data=self._data_table, reference_col_name=self.model_name,
                                        axis=0, alternative="two-sided")
    p_values = u_test_result.iloc[-1]
    return represent_p_values_as_asteriks(p_values)
def rename_model_indicator(self, data, model_name, model_indicator): def rename_model_indicator(self, data, model_name, model_indicator):
data.coords[self.model_type_dim] = [{model_indicator: model_name}.get(n, n) data.coords[self.model_type_dim] = [{model_indicator: model_name}.get(n, n)
for n in data.coords[self.model_type_dim].values] for n in data.coords[self.model_type_dim].values]
...@@ -1125,10 +1154,11 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover ...@@ -1125,10 +1154,11 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
self.error_measure = f"root {self.error_measure}" self.error_measure = f"root {self.error_measure}"
self.error_unit = self.error_unit.replace("$^2$", "") self.error_unit = self.error_unit.replace("$^2$", "")
def _plot(self, orientation: str = "v"): def _plot(self, orientation: str = "v", apply_u_test: bool = False):
data_table = self._data_table data_table = self._data_table
n_boots = self._n_boots n_boots = self._n_boots
size = len(np.unique(data_table.columns)) size = len(np.unique(data_table.columns))
asteriks = self.get_asteriks_from_mann_whitney_u_result
if orientation == "v": if orientation == "v":
figsize, width = (size, 5), 0.4 figsize, width = (size, 5), 0.4
elif orientation == "h": elif orientation == "h":
...@@ -1142,15 +1172,26 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover ...@@ -1142,15 +1172,26 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
boxprops={'facecolor': 'none', 'edgecolor': 'k'}, boxprops={'facecolor': 'none', 'edgecolor': 'k'},
width=width, orient=orientation) width=width, orient=orientation)
if orientation == "v": if orientation == "v":
if apply_u_test:
ax = self.set_sigificance_bars_vertical(asteriks, ax, data_table)
ylims = list(ax.get_ylim())
ax.set_ylim([ylims[0], ylims[1]*1.025])
ax.set_ylabel(f"{self.error_measure} (in {self.error_unit})") ax.set_ylabel(f"{self.error_measure} (in {self.error_unit})")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45) ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
elif orientation == "h": elif orientation == "h":
if apply_u_test:
ax = self.set_sigificance_bars_horizontal(asteriks, ax, data_table)
ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})") ax.set_xlabel(f"{self.error_measure} (in {self.error_unit})")
xlims = list(ax.get_xlim())
ax.set_xlim([xlims[0], xlims[1] * 1.015])
else: else:
raise ValueError(f"orientation must be `v' or `h' but is: {orientation}") raise ValueError(f"orientation must be `v' or `h' but is: {orientation}")
text = f"n={n_boots}" if self.block_length is None else f"{self.block_length}, n={n_boots}" text = f"n={n_boots}" if self.block_length is None else f"{self.block_length}, n={n_boots}"
loc = "upper right" if orientation == "h" else "upper left" # loc = "upper right" if orientation == "h" else "upper left"
text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5) loc = "lower left"
text_box = AnchoredText(text, frameon=True, loc=loc, pad=0.5, bbox_to_anchor=(0., 1.0),
bbox_transform=ax.transAxes)
plt.setp(text_box.patch, edgecolor='k', facecolor='w') plt.setp(text_box.patch, edgecolor='k', facecolor='w')
ax.add_artist(text_box) ax.add_artist(text_box)
plt.setp(ax.lines, color='k') plt.setp(ax.lines, color='k')
...@@ -1158,6 +1199,33 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover ...@@ -1158,6 +1199,33 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass): # pragma: no cover
self._save() self._save()
plt.close("all") plt.close("all")
def set_sigificance_bars_vertical(self, asteriks, ax, data_table):
    """
    Draw one annotated bracket between the reference model and every other entry (vertical orientation).

    Positions on the x axis are the positions of the entries in `asteriks`. Each bracket is placed 2.5% above the
    larger of (a) the maximum of the two compared columns and (b) the height of the previously drawn bracket, so
    that brackets are stacked with increasing height.

    :param asteriks: labels to write on top of the brackets, index must contain `self.model_name`
        (assumes the order of the index matches the columns of `data_table` -- TODO confirm)
    :param ax: axes to draw on
    :param data_table: data shown in the plot, used to derive the height of the brackets
    :return: the axes object handed in
    """
    reference_pos = list(asteriks.index).index(self.model_name)
    last_height = 0.
    for pos, label in enumerate(asteriks):
        if pos == reference_pos:
            continue  # no bracket from the reference to itself
        pair = data_table[[self.model_name, data_table.columns[pos]]]
        height = max(pair.max().max(), last_height) * 1.025
        tick = .01 * data_table.max().max()  # length of the bracket ends
        top = height + tick
        ax.plot([reference_pos, reference_pos, pos, pos], [height, top, top, height], c="k")
        ax.text((reference_pos + pos) * .5, top, label, ha="center", va="bottom", color="k")
        last_height = height
    return ax
def set_sigificance_bars_horizontal(self, asteriks, ax, data_table):
    """
    Draw one annotated bracket between the reference model and every other entry (horizontal orientation).

    Positions on the y axis are the positions of the entries in `asteriks`. Each bracket is placed 2.5% right of
    the larger of (a) the maximum of the two compared columns and (b) the position of the previously drawn bracket,
    so that brackets are stacked with increasing distance (same behaviour as the vertical counterpart).

    :param asteriks: labels to write next to the brackets, index must contain `self.model_name`
        (assumes the order of the index matches the columns of `data_table` -- TODO confirm)
    :param ax: axes to draw on
    :param data_table: data shown in the plot, used to derive the position of the brackets
    :return: the axes object handed in
    """
    y1 = list(asteriks.index).index(self.model_name)
    h = .01 * data_table.max().max()  # length of the bracket ends, identical for all brackets
    x_prev = 0.
    for y2, label in enumerate(asteriks):
        if y2 == y1:
            continue  # no bracket from the reference to itself
        x = data_table[[self.model_name, data_table.columns[y2]]].max().max()
        x = max(x, x_prev) * 1.025
        ax.plot([x, x + h, x + h, x], [y1, y1, y2, y2], c="k")
        ax.text(x + h, (y1 + y2) * .5, label, ha="left", va="center", color="k", rotation=-90)
        # remember the position, otherwise the next bracket is not pushed past this one
        x_prev = x
    return ax
if __name__ == "__main__": if __name__ == "__main__":
stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'] stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment