diff --git a/mlair/model_modules/branched_input_networks.py b/mlair/model_modules/branched_input_networks.py index 2c62c3cafc1537979e4a21bdb3bb6aa798e6e193..4df95867152f7871b17813a0f016ac19eb6dca92 100644 --- a/mlair/model_modules/branched_input_networks.py +++ b/mlair/model_modules/branched_input_networks.py @@ -1,4 +1,5 @@ from functools import partial, reduce +import copy from tensorflow import keras as keras @@ -6,6 +7,63 @@ from mlair import AbstractModelClass from mlair.helpers import select_from_dict from mlair.model_modules.loss import var_loss from mlair.model_modules.recurrent_networks import RNN +from mlair.model_modules.convolutional_networks import CNNfromConfig + + +class BranchedInputCNN(CNNfromConfig): # pragma: no cover + """A convolutional neural network with multiple input branches.""" + + def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs): + + super().__init__([input_shape], output_shape, layer_configuration, optimizer=optimizer, **kwargs) + + def set_model(self): + + x_input = [] + x_in = [] + stop_pos = None + + for branch in range(len(self._input_shape)): + print(branch) + shape_b = self._input_shape[branch] + x_input_b = keras.layers.Input(shape=shape_b, name=f"input_branch{branch + 1}") + x_input.append(x_input_b) + x_in_b = x_input_b + b_conf = copy.deepcopy(self.conf) + + for pos, layer_opts in enumerate(b_conf): + print(layer_opts) + if layer_opts.get("type") == "Concatenate": + if stop_pos is None: + stop_pos = pos + else: + assert pos == stop_pos + break + layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts) + x_in_b = layer(**layer_kwargs, name=f"{layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b) + if follow_up_layer is not None: + x_in_b = follow_up_layer(name=f"{follow_up_layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b) + self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer, + "branch": branch}) + x_in.append(x_in_b) + + print("concat") + x_concat = keras.layers.Concatenate()(x_in) + + if stop_pos is not None: + for layer_opts in self.conf[stop_pos + 1:]: + print(layer_opts) + layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts) + x_concat = layer(**layer_kwargs)(x_concat) + if follow_up_layer is not None: + x_concat = follow_up_layer()(x_concat) + self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer, + "branch": "concat"}) + + x_concat = keras.layers.Dense(self._output_shape)(x_concat) + out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat) + self.model = keras.Model(inputs=x_input, outputs=[out]) + print(self.model.summary()) class BranchedInputRNN(RNN): # pragma: no cover