Skip to content
Snippets Groups Projects
Commit 4ec0ab43 authored by leufen1's avatar leufen1
Browse files

introduce new class BranchedInputCNN

parent c0ac118c
No related branches found
No related tags found
Loading
Pipeline #94354 passed
from functools import partial, reduce from functools import partial, reduce
import copy
from tensorflow import keras as keras from tensorflow import keras as keras
...@@ -6,6 +7,63 @@ from mlair import AbstractModelClass ...@@ -6,6 +7,63 @@ from mlair import AbstractModelClass
from mlair.helpers import select_from_dict from mlair.helpers import select_from_dict
from mlair.model_modules.loss import var_loss from mlair.model_modules.loss import var_loss
from mlair.model_modules.recurrent_networks import RNN from mlair.model_modules.recurrent_networks import RNN
from mlair.model_modules.convolutional_networks import CNNfromConfig
class BranchedInputCNN(CNNfromConfig): # pragma: no cover
"""A convolutional neural network with multiple input branches."""
def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
super().__init__([input_shape], output_shape, layer_configuration, optimizer=optimizer, **kwargs)
def set_model(self):
x_input = []
x_in = []
stop_pos = None
for branch in range(len(self._input_shape)):
print(branch)
shape_b = self._input_shape[branch]
x_input_b = keras.layers.Input(shape=shape_b, name=f"input_branch{branch + 1}")
x_input.append(x_input_b)
x_in_b = x_input_b
b_conf = copy.deepcopy(self.conf)
for pos, layer_opts in enumerate(b_conf):
print(layer_opts)
if layer_opts.get("type") == "Concatenate":
if stop_pos is None:
stop_pos = pos
else:
assert pos == stop_pos
break
layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
x_in_b = layer(**layer_kwargs, name=f"{layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
if follow_up_layer is not None:
x_in_b = follow_up_layer(name=f"{follow_up_layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
"branch": branch})
x_in.append(x_in_b)
print("concat")
x_concat = keras.layers.Concatenate()(x_in)
if stop_pos is not None:
for layer_opts in self.conf[stop_pos + 1:]:
print(layer_opts)
layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
x_concat = layer(**layer_kwargs)(x_concat)
if follow_up_layer is not None:
x_concat = follow_up_layer()(x_concat)
self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
"branch": "concat"})
x_concat = keras.layers.Dense(self._output_shape)(x_concat)
out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
self.model = keras.Model(inputs=x_input, outputs=[out])
print(self.model.summary())
class BranchedInputRNN(RNN): # pragma: no cover class BranchedInputRNN(RNN): # pragma: no cover
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment