From 48ddb38d9551fea0eedc2685e1559c4f425abd5d Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Wed, 2 Mar 2022 15:22:38 +0100
Subject: [PATCH] added optimizer to CNNfromConfig

---
 mlair/model_modules/convolutional_networks.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py
index 8efe71a9..486441db 100644
--- a/mlair/model_modules/convolutional_networks.py
+++ b/mlair/model_modules/convolutional_networks.py
@@ -56,7 +56,7 @@ class CNNfromConfig(AbstractModelClass):
 
     """
 
-    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, **kwargs):
+    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
 
         assert len(input_shape) == 1
         assert len(output_shape) == 1
@@ -67,9 +67,9 @@ class CNNfromConfig(AbstractModelClass):
         self.activation_output = self._activation.get(activation_output)
         self.activation_output_name = activation_output
         self.kwargs = kwargs
+        self.optimizer = self._set_optimizer(optimizer, **kwargs)
 
         # apply to model
-
         self.set_model()
         self.set_compile_options()
         self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss)
@@ -90,6 +90,19 @@ class CNNfromConfig(AbstractModelClass):
         self.model = keras.Model(inputs=x_input, outputs=[out])
         print(self.model.summary())
 
+    def _set_optimizer(self, optimizer, **kwargs):
+        try:
+            opt_name = optimizer.lower()
+            opt = self._optimizer.get(opt_name)
+            opt_kwargs = {}
+            if opt_name == "adam":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
+            elif opt_name == "sgd":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
+            return opt(**opt_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
+
     def _set_regularizer(self, regularizer, **kwargs):
         if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
             return None
-- 
GitLab