diff --git a/mlair/helpers/statistics.py b/mlair/helpers/statistics.py index ad6a368fdf7980639802412201e964def80669b2..3631597aedb90b3411163a42490e9c023bad706a 100644 --- a/mlair/helpers/statistics.py +++ b/mlair/helpers/statistics.py @@ -196,9 +196,23 @@ def log_apply(data: Data, mean: Data, std: Data) -> Data: return standardise_apply(np.log1p(data), mean, std) -def mean_squared_error(a, b): +def mean_squared_error(a, b, dim=None): """Calculate mean squared error.""" - return np.square(a - b).mean() + return np.square(a - b).mean(dim) + + +def mean_absolute_error(a, b, dim=None): + """Calculate mean absolute error.""" + return np.abs(a - b).mean(dim) + + +def calculate_error_metrics(a, b, dim): + """Calculate MSE, RMSE, and MAE. Additionally return number of used values for calculation.""" + mse = mean_squared_error(a, b, dim) + rmse = np.sqrt(mse) + mae = mean_absolute_error(a, b, dim) + n = (a - b).notnull().sum(dim) + return {"mse": mse, "rmse": rmse, "mae": mae, "n": n} class SkillScores: