diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py index 11da0acbc76de73c34743254fecde62a9726699e..a9621af4a93413d4a0459b24d6cf2ab62fa2d01e 100644 --- a/mlair/model_modules/convolutional_networks.py +++ b/mlair/model_modules/convolutional_networks.py @@ -112,7 +112,8 @@ class CNN(AbstractModelClass): # pragma: no cover # apply to model self.set_model() self.set_compile_options() - self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss) + # self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss) + self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss) def _set_pooling(self, pooling): try: @@ -221,7 +222,9 @@ class CNN(AbstractModelClass): # pragma: no cover print(self.model.summary()) def set_compile_options(self): - self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])], + # self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])], + # "metrics": ["mse", "mae", var_loss]} + self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae", var_loss]}