From 363994d40a9c579545123a11b81e25c6d1ca040d Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 3 Mar 2021 17:58:58 +0100
Subject: [PATCH] new methods to calculate mae and a bunch of error metrics

---
 mlair/helpers/statistics.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py
index ad6a368f..3631597a 100644
--- a/mlair/helpers/statistics.py
+++ b/mlair/helpers/statistics.py
@@ -196,9 +196,23 @@ def log_apply(data: Data, mean: Data, std: Data) -> Data:
     return standardise_apply(np.log1p(data), mean, std)
 
 
-def mean_squared_error(a, b):
+def mean_squared_error(a, b, dim=None):
     """Calculate mean squared error."""
-    return np.square(a - b).mean()
+    return np.square(a - b).mean(dim)
+
+
+def mean_absolute_error(a, b, dim=None):
+    """Calculate mean absolute error."""
+    return np.abs(a - b).mean(dim)
+
+
+def calculate_error_metrics(a, b, dim):
+    """Calculate MSE, RMSE, and MAE. Additionally return number of used values for calculation."""
+    mse = mean_squared_error(a, b, dim)
+    rmse = np.sqrt(mse)
+    mae = mean_absolute_error(a, b, dim)
+    n = (a - b).notnull().sum(dim)
+    return {"mse": mse, "rmse": rmse, "mae": mae, "n": n}
 
 
 class SkillScores:
-- 
GitLab