From 1fea8025b178f83b2927a29f040ae1c58456dcab Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Tue, 1 Mar 2022 17:51:11 +0100 Subject: [PATCH] changed default loss --- mlair/model_modules/convolutional_networks.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py index 11da0acb..a9621af4 100644 --- a/mlair/model_modules/convolutional_networks.py +++ b/mlair/model_modules/convolutional_networks.py @@ -112,7 +112,8 @@ class CNN(AbstractModelClass): # pragma: no cover # apply to model self.set_model() self.set_compile_options() - self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss) + # self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss) + self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss) def _set_pooling(self, pooling): try: @@ -221,7 +222,9 @@ class CNN(AbstractModelClass): # pragma: no cover print(self.model.summary()) def set_compile_options(self): - self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])], + # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])], + # "metrics": ["mse", "mae", var_loss]} + self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae", var_loss]} -- GitLab