diff --git a/mlair/model_modules/branched_input_networks.py b/mlair/model_modules/branched_input_networks.py
index 4df95867152f7871b17813a0f016ac19eb6dca92..a7841f6aab031647e0a3a6a8af6b7c648179cbc3 100644
--- a/mlair/model_modules/branched_input_networks.py
+++ b/mlair/model_modules/branched_input_networks.py
@@ -4,7 +4,7 @@ import copy
 from tensorflow import keras as keras

 from mlair import AbstractModelClass
-from mlair.helpers import select_from_dict
+from mlair.helpers import select_from_dict, to_list
 from mlair.model_modules.loss import var_loss
 from mlair.model_modules.recurrent_networks import RNN
 from mlair.model_modules.convolutional_networks import CNNfromConfig
@@ -42,7 +42,8 @@ class BranchedInputCNN(CNNfromConfig):  # pragma: no cover
                 layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
                 x_in_b = layer(**layer_kwargs, name=f"{layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
                 if follow_up_layer is not None:
-                    x_in_b = follow_up_layer(name=f"{follow_up_layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
+                    for follow_up in to_list(follow_up_layer):
+                        x_in_b = follow_up(name=f"{follow_up.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
                 self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                          "branch": branch})
             x_in.append(x_in_b)
@@ -51,12 +52,13 @@ class BranchedInputCNN(CNNfromConfig):  # pragma: no cover
         x_concat = keras.layers.Concatenate()(x_in)

         if stop_pos is not None:
-            for layer_opts in self.conf[stop_pos + 1:]:
+            for pos, layer_opts in enumerate(self.conf[stop_pos + 1:]):
                 print(layer_opts)
                 layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
-                x_concat = layer(**layer_kwargs)(x_concat)
+                x_concat = layer(**layer_kwargs, name=f"{layer.__name__}_{pos + stop_pos + 1}")(x_concat)
                 if follow_up_layer is not None:
-                    x_concat = follow_up_layer()(x_concat)
+                    for follow_up in to_list(follow_up_layer):
+                        x_concat = follow_up(name=f"{follow_up.__name__}_{pos + stop_pos + 1}")(x_concat)
                 self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
                                          "branch": "concat"})

@@ -73,11 +75,6 @@ class BranchedInputRNN(RNN):  # pragma: no cover

         super().__init__([input_shape], output_shape, *args, **kwargs)

-        # apply to model
-        # self.set_model()
-        # self.set_compile_options()
-        # self.set_custom_objects(loss=self.compile_options["loss"][0], var_loss=var_loss)
-
     def set_model(self):
         """
         Build the model.
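
Threading follow_up_layer through to_list lets a branch apply several follow-up layers in
sequence (for example the (BatchNormalization, activation) tuples built in
convolutional_networks.py below) while a plain single activation keeps working unchanged.
A minimal sketch of the pattern, assuming mlair.helpers.to_list converts tuples to lists
and wraps any other single value into a one-element list:

    from tensorflow import keras

    def to_list(obj):  # assumed behavior of mlair.helpers.to_list
        return list(obj) if isinstance(obj, (list, tuple, set)) else [obj]

    x = keras.layers.Input(shape=(16,))
    x = keras.layers.Dense(8, name="Dense_branch1_1")(x)
    follow_up_layer = (keras.layers.BatchNormalization, keras.layers.ReLU)  # tuple case
    for follow_up in to_list(follow_up_layer):
        # Same naming scheme as in set_model above; the names stay unique
        # because each layer's class name is part of its layer name.
        x = follow_up(name=f"{follow_up.__name__}_branch1_1")(x)
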
diff --git a/mlair/model_modules/convolutional_networks.py b/mlair/model_modules/convolutional_networks.py
index 7f6867343800c198af5ca23538de335e1b050c16..3e87bb7c595e4a0f709ca73449043ad74e9b6e9d 100644
--- a/mlair/model_modules/convolutional_networks.py
+++ b/mlair/model_modules/convolutional_networks.py
@@ -4,7 +4,7 @@ __date__ = '2021-02-'
 from functools import reduce, partial

 from mlair.model_modules import AbstractModelClass
-from mlair.helpers import select_from_dict
+from mlair.helpers import select_from_dict, to_list
 from mlair.model_modules.loss import var_loss, custom_loss
 from mlair.model_modules.advanced_paddings import PadUtils, Padding2D, SymmetricPadding2D

@@ -56,7 +56,8 @@ class CNNfromConfig(AbstractModelClass):

     """

-    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
+    def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam",
+                 batch_normalization=False, **kwargs):

         assert len(input_shape) == 1
         assert len(output_shape) == 1
@@ -67,6 +68,7 @@ class CNNfromConfig(AbstractModelClass):
         self.activation_output = self._activation.get(activation_output)
         self.activation_output_name = activation_output
         self.kwargs = kwargs
+        self.bn = batch_normalization
         self.optimizer = self._set_optimizer(optimizer, **kwargs)
         self._layer_save = []

@@ -84,7 +86,8 @@ class CNNfromConfig(AbstractModelClass):
             layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
             x_in = layer(**layer_kwargs)(x_in)
             if follow_up_layer is not None:
-                x_in = follow_up_layer()(x_in)
+                for follow_up in to_list(follow_up_layer):
+                    x_in = follow_up()(x_in)
             self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer})

         x_in = keras.layers.Dense(self._output_shape)(x_in)
@@ -138,6 +141,12 @@ class CNNfromConfig(AbstractModelClass):
             kernel_initializer = self._initializer.get(activation_type, "glorot_uniform")
             layer_opts["kernel_initializer"] = kernel_initializer
             follow_up_layer = activation
+            if self.bn is True:
+                another_layer = keras.layers.BatchNormalization
+                if activation_type in ["relu", "linear", "prelu", "leakyrelu"]:
+                    follow_up_layer = (another_layer, follow_up_layer)
+                else:
+                    follow_up_layer = (follow_up_layer, another_layer)
         regularizer_type = layer_opts.pop("kernel_regularizer", None)
         if regularizer_type is not None:
             layer_opts["kernel_regularizer"] = self._set_regularizer(regularizer_type, **self.kwargs)
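
With batch_normalization enabled, _extract_layer_conf now returns the follow-up layers as a
tuple: for the "relu", "linear", "prelu" and "leakyrelu" activations the BatchNormalization
layer runs before the activation, for all other activations (e.g. "tanh") after it. A hedged
usage sketch; the shapes and layer_configuration entries below are illustrative only, but
follow the {"type": ..., "activation": ...} scheme that _extract_layer_conf consumes:

    from mlair.model_modules.convolutional_networks import CNNfromConfig

    model = CNNfromConfig(
        input_shape=[(65, 1, 9)],    # illustrative (window, 1, variables) shape
        output_shape=[4],
        layer_configuration=[
            {"type": "Conv2D", "filters": 16, "kernel_size": (3, 1), "activation": "relu"},
            {"type": "Flatten"},
            {"type": "Dense", "units": 32, "activation": "tanh"},
        ],
        batch_normalization=True,    # new flag from this patch, defaults to False
    )

With this configuration the Conv2D block expands to Conv2D -> BatchNormalization -> ReLU,
while the Dense block expands to Dense -> tanh -> BatchNormalization.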