Skip to content
Snippets Groups Projects

Resolve "release v2.3.0"

Merged Ghost User requested to merge release_v2.3.0 into master
1 file
+ 11
1
Compare changes
  • Side-by-side
  • Inline
__author__ = "Lukas Leufen"
__date__ = "2021-08-23"
from functools import partial
from mlair.model_modules.branched_input_networks import BranchedInputCNN
@@ -46,7 +47,10 @@ class BranchedInputResNet(BranchedInputCNN):
layer_name = layer_kwargs.pop("name").split("_")
layer_name = "_".join([*layer_name[0:2], "%s", *layer_name[2:]])
act = layer_kwargs.pop("activation")
act_name = act.__name__
if isinstance(act, partial):
act_name = act.args[0] if act.func.__name__ == "Activation" else act.func.__name__
else:
act_name = act.__name__
use_1x1conv = layer_kwargs.pop("use_1x1conv", False)
def block(x):
@@ -74,6 +78,12 @@ class BranchedInputResNet(BranchedInputCNN):
kernel_initializer = self._initializer.get(activation_type, "glorot_uniform")
layer_opts["kernel_initializer"] = kernel_initializer
follow_up_layer = activation
if self.bn is True and layer_type.lower() != "residual_block":
another_layer = keras.layers.BatchNormalization
if activation_type in ["relu", "linear", "prelu", "leakyrelu"]:
follow_up_layer = (another_layer, follow_up_layer)
else:
follow_up_layer = (follow_up_layer, another_layer)
regularizer_type = layer_opts.pop("kernel_regularizer", None)
if regularizer_type is not None:
layer_opts["kernel_regularizer"] = self._set_regularizer(regularizer_type, **self.kwargs)
Loading