diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py
index 8efe71a9f3fefce331c8793ab78e8d422b00f35c..486441db2944b5223ef78715ca2541b2f8eac210 100644
--- a/mlair/model_modules/convolutional_networks.py
+++ b/mlair/model_modules/convolutional_networks.py
@@ -56,7 +56,7 @@ class CNNfromConfig(AbstractModelClass):
 
     """
 
-    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, **kwargs):
+    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
 
         assert len(input_shape) == 1
         assert len(output_shape) == 1
@@ -67,9 +67,9 @@ class CNNfromConfig(AbstractModelClass):
         self.activation_output = self._activation.get(activation_output)
         self.activation_output_name = activation_output
         self.kwargs = kwargs
+        self.optimizer = self._set_optimizer(optimizer, **kwargs)
 
         # apply to model
-
         self.set_model()
         self.set_compile_options()
         self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss)
@@ -90,6 +90,35 @@ class CNNfromConfig(AbstractModelClass):
         self.model = keras.Model(inputs=x_input, outputs=[out])
         print(self.model.summary())
 
+    def _set_optimizer(self, optimizer, **kwargs):
+        """
+        Create the optimizer registered under the given name.
+
+        The name is lower-cased and looked up in self._optimizer. For "adam" and "sgd", a fixed list of setting
+        names is passed together with kwargs to select_from_dict and the result is forwarded to the optimizer.
+
+        :param optimizer: name of the optimizer (compared in lower case)
+        :param kwargs: keyword arguments of the model, may contain optimizer settings
+        :return: result of calling the entry found in self._optimizer
+        :raises AttributeError: if nothing is registered in self._optimizer under the given name
+        """
+        unsupported = f"Given optimizer {optimizer} is not supported in this model class."
+        try:
+            opt_name = optimizer.lower()
+            opt = self._optimizer.get(opt_name)
+            if opt is None:
+                # assumes self._optimizer is a mapping: its .get() yields None for an unknown name instead of
+                # raising KeyError, which would otherwise surface as a TypeError when calling opt below
+                raise AttributeError(unsupported)
+            opt_kwargs = {}
+            if opt_name == "adam":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
+            elif opt_name == "sgd":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
+            return opt(**opt_kwargs)
+        except KeyError:
+            raise AttributeError(unsupported)
+
     def _set_regularizer(self, regularizer, **kwargs):
         if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
             return None