diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py
index 8efe71a9f3fefce331c8793ab78e8d422b00f35c..486441db2944b5223ef78715ca2541b2f8eac210 100644
--- a/mlair/model_modules/convolutional_networks.py
+++ b/mlair/model_modules/convolutional_networks.py
@@ -56,7 +56,7 @@ class CNNfromConfig(AbstractModelClass):
 
     """
 
-    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, **kwargs):
+    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
 
         assert len(input_shape) == 1
         assert len(output_shape) == 1
@@ -67,9 +67,9 @@ class CNNfromConfig(AbstractModelClass):
         self.activation_output = self._activation.get(activation_output)
         self.activation_output_name = activation_output
         self.kwargs = kwargs
+        self.optimizer = self._set_optimizer(optimizer, **kwargs)
 
         # apply to model
-
         self.set_model()
         self.set_compile_options()
         self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss)
@@ -90,6 +90,19 @@ class CNNfromConfig(AbstractModelClass):
         self.model = keras.Model(inputs=x_input, outputs=[out])
         print(self.model.summary())
 
+    def _set_optimizer(self, optimizer, **kwargs):
+        try:
+            opt_name = optimizer.lower()
+            opt = self._optimizer.get(opt_name)
+            opt_kwargs = {}
+            if opt_name == "adam":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
+            elif opt_name == "sgd":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
+            return opt(**opt_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
+
     def _set_regularizer(self, regularizer, **kwargs):
         if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
             return None