From a997bce528a470892e037e5dc87228e45b339230 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Fri, 16 Sep 2022 10:41:59 +0200
Subject: [PATCH] implemented mean error (bias) as additional error metric

---
 mlair/helpers/statistics.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index 5f3aa451..ad344ab3 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -213,6 +213,11 @@ def mean_absolute_error(a, b, dim=None):
     return np.abs(a - b).mean(dim)
 
 
+def mean_error(a, b, dim=None):
+    """Calculate mean error where a is forecast and b the reference (e.g. observation)."""
+    return a.mean(dim) - b.mean(dim)
+
+
 def index_of_agreement(a, b, dim=None):
     """Calculate index of agreement (IOA) where a is the forecast and b the reference (e.g. observation)."""
     num = (np.square(b - a)).sum(dim)
@@ -234,7 +239,7 @@ def modified_normalized_mean_bias(a, b, dim=None):
 
 
 def calculate_error_metrics(a, b, dim):
-    """Calculate MSE, RMSE, MAE, IOA, and MNMB. Additionally, return number of used values for calculation.
+    """Calculate MSE, ME, RMSE, MAE, IOA, and MNMB. Additionally, return number of used values for calculation.
 
     :param a: forecast data to calculate metrics for
     :param b: reference (e.g. observation)
@@ -243,12 +248,13 @@ def calculate_error_metrics(a, b, dim):
     :returns: dict with results for all metrics indicated by lowercase metric short name
     """
     mse = mean_squared_error(a, b, dim)
+    me = mean_error(a, b, dim)
     rmse = np.sqrt(mse)
     mae = mean_absolute_error(a, b, dim)
     ioa = index_of_agreement(a, b, dim)
     mnmb = modified_normalized_mean_bias(a, b, dim)
     n = (a - b).notnull().sum(dim)
-    return {"mse": mse, "rmse": rmse, "mae": mae, "ioa": ioa, "mnmb": mnmb, "n": n}
+    return {"mse": mse, "me": me, "rmse": rmse, "mae": mae, "ioa": ioa, "mnmb": mnmb, "n": n}
 
 
 def mann_whitney_u_test(data: pd.DataFrame, reference_col_name: str, **kwargs):
-- 
GitLab