Commit 0aa25bbb authored by Lukas Leufen

Merge branch 'lukas_issue419_feat_resnet-model-class' into 'develop'

Lukas issue419 feat resnet model class

See merge request !481
parents c1dfc1b3 656f5d16
3 merge requests: !500 Develop, !499 Resolve "release v2.3.0", !481 Lukas issue419 feat resnet model class
Pipeline #110003 passed
__author__ = "Lukas Leufen"
__date__ = "2021-08-23"
from functools import partial
from mlair.model_modules.branched_input_networks import BranchedInputCNN
@@ -46,6 +47,9 @@ class BranchedInputResNet(BranchedInputCNN):
layer_name = layer_kwargs.pop("name").split("_")
layer_name = "_".join([*layer_name[0:2], "%s", *layer_name[2:]])
act = layer_kwargs.pop("activation")
if isinstance(act, partial):
act_name = act.args[0] if act.func.__name__ == "Activation" else act.func.__name__
else:
act_name = act.__name__
use_1x1conv = layer_kwargs.pop("use_1x1conv", False)
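As a side note on the hunk above, the following is a minimal, self-contained sketch of the added name-extraction logic; the helper name activation_name and the example inputs are illustrative and not part of the merge request. A functools.partial that wraps keras' generic Activation layer yields the name from its first positional argument, any other partial yields the wrapped class name, and a plain callable falls back to its __name__.

from functools import partial
from tensorflow import keras

def activation_name(act):
    # partial(keras.layers.Activation, "relu")    -> "relu"
    # partial(keras.layers.LeakyReLU, alpha=0.1)  -> "LeakyReLU"
    # keras.activations.linear                    -> "linear"
    if isinstance(act, partial):
        return act.args[0] if act.func.__name__ == "Activation" else act.func.__name__
    return act.__name__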
@@ -74,6 +78,12 @@ class BranchedInputResNet(BranchedInputCNN):
kernel_initializer = self._initializer.get(activation_type, "glorot_uniform")
layer_opts["kernel_initializer"] = kernel_initializer
follow_up_layer = activation
if self.bn is True and layer_type.lower() != "residual_block":
another_layer = keras.layers.BatchNormalization
if activation_type in ["relu", "linear", "prelu", "leakyrelu"]:
follow_up_layer = (another_layer, follow_up_layer)
else:
follow_up_layer = (follow_up_layer, another_layer)
regularizer_type = layer_opts.pop("kernel_regularizer", None)
if regularizer_type is not None:
layer_opts["kernel_regularizer"] = self._set_regularizer(regularizer_type, **self.kwargs)
......
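To illustrate the batch-normalization placement introduced in the second hunk, here is a hedged sketch rather than code from the commit; the function build_follow_up and its arguments are made up for the example. The rule mirrored below: BatchNormalization precedes the activation for relu, linear, prelu and leakyrelu, follows it for any other activation, and is skipped entirely for residual blocks.

from tensorflow import keras

def build_follow_up(activation, activation_type, use_bn, layer_type):
    # BN before relu/linear/prelu/leakyrelu, after all other activations,
    # and omitted inside residual blocks (hypothetical helper, for illustration).
    follow_up_layer = activation
    if use_bn is True and layer_type.lower() != "residual_block":
        another_layer = keras.layers.BatchNormalization
        if activation_type in ["relu", "linear", "prelu", "leakyrelu"]:
            follow_up_layer = (another_layer, follow_up_layer)   # BN, then activation
        else:
            follow_up_layer = (follow_up_layer, another_layer)   # activation, then BN
    return follow_up_layer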