From 48ddb38d9551fea0eedc2685e1559c4f425abd5d Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Wed, 2 Mar 2022 15:22:38 +0100 Subject: [PATCH] added optimizer to CNNfromConfig --- mlair/model_modules/convolutional_networks.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py index 8efe71a9..486441db 100644 --- a/mlair/model_modules/convolutional_networks.py +++ b/mlair/model_modules/convolutional_networks.py @@ -56,7 +56,7 @@ class CNNfromConfig(AbstractModelClass): """ - def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, **kwargs): + def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs): assert len(input_shape) == 1 assert len(output_shape) == 1 @@ -67,9 +67,9 @@ class CNNfromConfig(AbstractModelClass): self.activation_output = self._activation.get(activation_output) self.activation_output_name = activation_output self.kwargs = kwargs + self.optimizer = self._set_optimizer(optimizer, **kwargs) # apply to model - self.set_model() self.set_compile_options() self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss) @@ -90,6 +90,19 @@ class CNNfromConfig(AbstractModelClass): self.model = keras.Model(inputs=x_input, outputs=[out]) print(self.model.summary()) + def _set_optimizer(self, optimizer, **kwargs): + try: + opt_name = optimizer.lower() + opt = self._optimizer.get(opt_name) + opt_kwargs = {} + if opt_name == "adam": + opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"]) + elif opt_name == "sgd": + opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"]) + return opt(**opt_kwargs) + except KeyError: + raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.") + def _set_regularizer(self, regularizer, **kwargs): if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"): return None -- GitLab