Commit 48ddb38d authored by leufen1
added optimizer to CNNfromConfig

parent 259e00b3
@@ -56,7 +56,7 @@ class CNNfromConfig(AbstractModelClass):
     """
-    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, **kwargs):
+    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
         assert len(input_shape) == 1
         assert len(output_shape) == 1
@@ -67,9 +67,9 @@ class CNNfromConfig(AbstractModelClass):
         self.activation_output = self._activation.get(activation_output)
         self.activation_output_name = activation_output
         self.kwargs = kwargs
+        self.optimizer = self._set_optimizer(optimizer, **kwargs)

         # apply to model
         self.set_model()
         self.set_compile_options()
         self.set_custom_objects(loss=custom_loss([keras.losses.mean_squared_error, var_loss]), var_loss=var_loss)
@@ -90,6 +90,19 @@ class CNNfromConfig(AbstractModelClass):
         self.model = keras.Model(inputs=x_input, outputs=[out])
         print(self.model.summary())

+    def _set_optimizer(self, optimizer, **kwargs):
+        try:
+            opt_name = optimizer.lower()
+            opt = self._optimizer.get(opt_name)
+            opt_kwargs = {}
+            if opt_name == "adam":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
+            elif opt_name == "sgd":
+                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
+            return opt(**opt_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
+
     def _set_regularizer(self, regularizer, **kwargs):
         if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
             return None
...
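The new _set_optimizer method leans on two pieces that are not part of this diff: the class attribute self._optimizer (a name-to-class lookup) and the helper select_from_dict (a kwargs filter). Below is a minimal, self-contained sketch of how the mechanism could fit together; the mapping, the helper, and the surrounding class name are hypothetical stand-ins and may differ from the actual definitions elsewhere in the code base. The sketch also indexes the mapping directly instead of using .get, so that an unsupported optimizer name raises the KeyError the except branch is written to catch.

import tensorflow.keras as keras  # assumption: keras optimizers, as used elsewhere in the class


def select_from_dict(d, keys):
    # Hypothetical stand-in: keep only the entries whose key is listed in `keys`.
    return {k: v for k, v in d.items() if k in keys}


class OptimizerMixinSketch:
    # Assumed name-to-class mapping that _set_optimizer looks names up in.
    _optimizer = {"adam": keras.optimizers.Adam, "sgd": keras.optimizers.SGD}

    def _set_optimizer(self, optimizer, **kwargs):
        try:
            opt_name = optimizer.lower()
            # Direct indexing so an unsupported name raises KeyError below.
            opt = self._optimizer[opt_name]
            opt_kwargs = {}
            if opt_name == "adam":
                opt_kwargs = select_from_dict(kwargs, ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad"])
            elif opt_name == "sgd":
                opt_kwargs = select_from_dict(kwargs, ["lr", "momentum", "decay", "nesterov"])
            return opt(**opt_kwargs)
        except KeyError:
            raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")


if __name__ == "__main__":
    sketch = OptimizerMixinSketch()
    # Only kwargs relevant to the chosen optimizer are forwarded; unrelated
    # kwargs (e.g. layer settings) are filtered out before instantiation.
    opt = sketch._set_optimizer("sgd", momentum=0.9, nesterov=True, kernel_size=3)
    print(type(opt).__name__)  # -> SGD

In effect, constructing CNNfromConfig(..., optimizer="sgd", lr=0.01, momentum=0.9) would build an SGD optimizer from the matching kwargs while ignoring the rest, with "adam" remaining the default.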