From e5975374f43ce49da0bb69cfdcb4861e33940666 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Wed, 24 Aug 2022 09:27:14 +0200 Subject: [PATCH] enable bn for layers that are not residual blocks --- mlair/model_modules/residual_networks.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlair/model_modules/residual_networks.py b/mlair/model_modules/residual_networks.py index 5542c1da..2c8f3cfb 100644 --- a/mlair/model_modules/residual_networks.py +++ b/mlair/model_modules/residual_networks.py @@ -74,6 +74,12 @@ class BranchedInputResNet(BranchedInputCNN): kernel_initializer = self._initializer.get(activation_type, "glorot_uniform") layer_opts["kernel_initializer"] = kernel_initializer follow_up_layer = activation + if self.bn is True and layer_type.lower() != "residual_block": + another_layer = keras.layers.BatchNormalization + if activation_type in ["relu", "linear", "prelu", "leakyrelu"]: + follow_up_layer = (another_layer, follow_up_layer) + else: + follow_up_layer = (follow_up_layer, another_layer) regularizer_type = layer_opts.pop("kernel_regularizer", None) if regularizer_type is not None: layer_opts["kernel_regularizer"] = self._set_regularizer(regularizer_type, **self.kwargs) -- GitLab