diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 113765e0d295bb0b1d756cd1cefba85093b20089..6c993d56b540cf3cf5b86d9c1920fc3a22557e46 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -223,7 +223,7 @@ class Training(RunEnvironment):
         if multiple_branches_used:
             filename = os.path.join(path, f"{name}_history_main_loss.pdf")
             PlotModelHistory(filename=filename, history=history, main_branch=True)
-        if "mean_squared_error" in history.model.metrics_names:
+        if len([e for e in history.model.metrics_names if "mean_squared_error" in e]) > 0:
             filename = os.path.join(path, f"{name}_history_main_mse.pdf")
             PlotModelHistory(filename=filename, history=history, plot_metric="mse", main_branch=multiple_branches_used)