From 1fea8025b178f83b2927a29f040ae1c58456dcab Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 1 Mar 2022 17:51:11 +0100
Subject: [PATCH] changed default loss

---
 mlair/model_modules/convolutional_networks.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py
index 11da0acb..a9621af4 100644
--- a/mlair/model_modules/convolutional_networks.py
+++ b/mlair/model_modules/convolutional_networks.py
@@ -112,7 +112,8 @@ class CNN(AbstractModelClass):  # pragma: no cover
         # apply to model
         self.set_model()
         self.set_compile_options()
-        self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss)
+        # self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss)
+        self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
 
     def _set_pooling(self, pooling):
         try:
@@ -221,7 +222,9 @@ class CNN(AbstractModelClass):  # pragma: no cover
         print(self.model.summary())
 
     def set_compile_options(self):
-        self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])],
+        # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])],
+        #                         "metrics": ["mse", "mae", var_loss]}
+        self.compile_options = {"loss": [keras.losses.mean_squared_error],
                                 "metrics": ["mse", "mae", var_loss]}
 
 
-- 
GitLab