Commit 1fff2055 authored by leufen1

added batch normalization to cnn and branched cnn

parent 4ec0ab43
5 merge requests: !430 update recent developments, !413 update release branch, !412 Resolve "release v2.0.0", !406 Lukas issue368 feat prepare cnn class for filter benchmarking, !403 Resolve "prepare CNN class for filter benchmarking"
Pipeline #94427 passed
@@ -4,7 +4,7 @@ import copy
 from tensorflow import keras as keras
 from mlair import AbstractModelClass
-from mlair.helpers import select_from_dict
+from mlair.helpers import select_from_dict, to_list
 from mlair.model_modules.loss import var_loss
 from mlair.model_modules.recurrent_networks import RNN
 from mlair.model_modules.convolutional_networks import CNNfromConfig
@@ -42,7 +42,8 @@ class BranchedInputCNN(CNNfromConfig):  # pragma: no cover
                 layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                 x_in_b = layer(**layer_kwargs, name=f"{layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
                 if follow_up_layer is not None:
-                    x_in_b = follow_up_layer(name=f"{follow_up_layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
+                    for follow_up in to_list(follow_up_layer):
+                        x_in_b = follow_up(name=f"{follow_up.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
                 self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                          "branch": branch})
             x_in.append(x_in_b)
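
For context, here is how the new to_list loop unrolls when a follow-up entry is a tuple of layer classes. The helper body below is an assumed minimal stand-in for mlair.helpers.to_list, not taken from this diff, and the input shape is illustrative:

from tensorflow import keras

def to_list(obj):
    # Assumed minimal behavior of mlair.helpers.to_list: tuples and sets
    # become lists, single objects get wrapped, lists pass through unchanged.
    if isinstance(obj, (tuple, set)):
        return list(obj)
    return obj if isinstance(obj, list) else [obj]

# With batch normalization enabled, follow_up_layer may now be a tuple of
# layer classes rather than a single class; the loop applies each in order.
follow_up_layer = (keras.layers.BatchNormalization, keras.layers.ReLU)
branch, pos = 0, 0
x_in_b = keras.layers.Input(shape=(30, 1, 16))
for follow_up in to_list(follow_up_layer):
    x_in_b = follow_up(name=f"{follow_up.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)

Because BatchNormalization and the activation class expose different __name__ values, the generated layer names stay unique within a single branch position.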
@@ -51,12 +52,13 @@
         x_concat = keras.layers.Concatenate()(x_in)
         if stop_pos is not None:
-            for layer_opts in self.conf[stop_pos + 1:]:
+            for pos, layer_opts in enumerate(self.conf[stop_pos + 1:]):
                 print(layer_opts)
                 layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
-                x_concat = layer(**layer_kwargs)(x_concat)
+                x_concat = layer(**layer_kwargs, name=f"{layer.__name__}_{pos + stop_pos + 1}")(x_concat)
                 if follow_up_layer is not None:
-                    x_concat = follow_up_layer()(x_concat)
+                    for follow_up in to_list(follow_up_layer):
+                        x_concat = follow_up(name=f"{follow_up.__name__}_{pos + stop_pos + 1}")(x_concat)
                 self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                          "branch": "concat"})
@@ -73,11 +75,6 @@ class BranchedInputRNN(RNN):  # pragma: no cover
         super().__init__([input_shape], output_shape, *args, **kwargs)
-        # apply to model
-        # self.set_model()
-        # self.set_compile_options()
-        # self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
     def set_model(self):
         """
         Build the model.
@@ -4,7 +4,7 @@ __date__ = '2021-02-'
 from functools import reduce, partial
 from mlair.model_modules import AbstractModelClass
-from mlair.helpers import select_from_dict
+from mlair.helpers import select_from_dict, to_list
 from mlair.model_modules.loss import var_loss, custom_loss
 from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D
@@ -56,7 +56,8 @@ class CNNfromConfig(AbstractModelClass):
     """
-    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
+    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam",
+                 batch_normalization=False, **kwargs):
         assert len(input_shape) == 1
         assert len(output_shape) == 1
@@ -67,6 +68,7 @@ class CNNfromConfig(AbstractModelClass):
         self.activation_output = self._activation.get(activation_output)
         self.activation_output_name = activation_output
         self.kwargs = kwargs
+        self.bn = batch_normalization
         self.optimizer = self._set_optimizer(optimizer, **kwargs)
         self._layer_save = []
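
A hedged usage sketch of the extended constructor follows. The layer_configuration schema is consumed by _extract_layer_conf and is not shown in this diff, so the config dicts below are an assumed "type plus keyword arguments" illustration, as are the concrete shapes:

from mlair.model_modules.convolutional_networks import CNNfromConfig

# input_shape and output_shape are single-element lists per the asserts above;
# all values here are placeholder assumptions, not taken from this commit.
model = CNNfromConfig(
    input_shape=[(65, 1, 9)],
    output_shape=[4],
    layer_configuration=[
        {"type": "Conv2D", "filters": 16, "kernel_size": (3, 1), "activation": "relu"},
        {"type": "Flatten"},
        {"type": "Dense", "units": 32, "activation": "relu"},
    ],
    batch_normalization=True,  # new flag introduced by this commit
)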
@@ -84,7 +86,8 @@ class CNNfromConfig(AbstractModelClass):
             layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
             x_in = layer(**layer_kwargs)(x_in)
             if follow_up_layer is not None:
-                x_in = follow_up_layer()(x_in)
+                for follow_up in to_list(follow_up_layer):
+                    x_in = follow_up()(x_in)
             self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer})
         x_in = keras.layers.Dense(self._output_shape)(x_in)
@@ -138,6 +141,12 @@ class CNNfromConfig(AbstractModelClass):
             kernel_initializer = self._initializer.get(activation_type, "glorot_uniform")
             layer_opts["kernel_initializer"] = kernel_initializer
             follow_up_layer = activation
+            if self.bn is True:
+                another_layer = keras.layers.BatchNormalization
+                if activation_type in ["relu", "linear", "prelu", "leakyrelu"]:
+                    follow_up_layer = (another_layer, follow_up_layer)
+                else:
+                    follow_up_layer = (follow_up_layer, another_layer)
         regularizer_type = layer_opts.pop("kernel_regularizer", None)
         if regularizer_type is not None:
             layer_opts["kernel_regularizer"] = self._set_regularizer(regularizer_type, **self.kwargs)
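
The new branch inserts BatchNormalization before the activation for relu, linear, prelu and leakyrelu, and after it otherwise. A stand-alone Keras illustration of the two resulting layer orders (shapes and layer choices are assumptions for demonstration):

from tensorflow import keras

inputs = keras.layers.Input(shape=(32, 1, 9))

# relu-like activation: Conv2D -> BatchNormalization -> activation
x = keras.layers.Conv2D(16, (3, 1))(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)

# other (e.g. saturating) activation: Conv2D -> activation -> BatchNormalization
y = keras.layers.Conv2D(16, (3, 1))(inputs)
y = keras.layers.Activation("tanh")(y)
y = keras.layers.BatchNormalization()(y)

model = keras.Model(inputs=inputs, outputs=[x, y])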