Skip to content
Snippets Groups Projects
Commit 4ec0ab43 authored by leufen1's avatar leufen1
Browse files

introduce new class BranchedInputCNN

parent c0ac118c
No related branches found
No related tags found
5 merge requests!430update recent developments,!413update release branch,!412Resolve "release v2.0.0",!406Lukas issue368 feat prepare cnn class for filter benchmarking,!403Resolve "prepare CNN class for filter benchmarking"
Pipeline #94354 passed
from functools import partial, reduce from functools import partial, reduce
import copy
from tensorflow import keras as keras from tensorflow import keras as keras
...@@ -6,6 +7,63 @@ from mlair import AbstractModelClass ...@@ -6,6 +7,63 @@ from mlair import AbstractModelClass
from mlair.helpers import select_from_dict from mlair.helpers import select_from_dict
from mlair.model_modules.loss import var_loss from mlair.model_modules.loss import var_loss
from mlair.model_modules.recurrent_networks import RNN from mlair.model_modules.recurrent_networks import RNN
from mlair.model_modules.convolutional_networks import CNNfromConfig
class BranchedInputCNN(CNNfromConfig): # pragma: no cover
"""A convolutional neural network with multiple input branches."""
def __init__(self, input_shape: list, output_shape: list, layer_configuration: list, optimizer="adam", **kwargs):
super().__init__([input_shape], output_shape, layer_configuration, optimizer=optimizer, **kwargs)
def set_model(self):
x_input = []
x_in = []
stop_pos = None
for branch in range(len(self._input_shape)):
print(branch)
shape_b = self._input_shape[branch]
x_input_b = keras.layers.Input(shape=shape_b, name=f"input_branch{branch + 1}")
x_input.append(x_input_b)
x_in_b = x_input_b
b_conf = copy.deepcopy(self.conf)
for pos, layer_opts in enumerate(b_conf):
print(layer_opts)
if layer_opts.get("type") == "Concatenate":
if stop_pos is None:
stop_pos = pos
else:
assert pos == stop_pos
break
layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
x_in_b = layer(**layer_kwargs, name=f"{layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
if follow_up_layer is not None:
x_in_b = follow_up_layer(name=f"{follow_up_layer.__name__}_branch{branch + 1}_{pos + 1}")(x_in_b)
self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
"branch": branch})
x_in.append(x_in_b)
print("concat")
x_concat = keras.layers.Concatenate()(x_in)
if stop_pos is not None:
for layer_opts in self.conf[stop_pos + 1:]:
print(layer_opts)
layer, layer_kwargs, follow_up_layer = self._extract_layer_conf(layer_opts)
x_concat = layer(**layer_kwargs)(x_concat)
if follow_up_layer is not None:
x_concat = follow_up_layer()(x_concat)
self._layer_save.append({"layer": layer, **layer_kwargs, "follow_up_layer": follow_up_layer,
"branch": "concat"})
x_concat = keras.layers.Dense(self._output_shape)(x_concat)
out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
self.model = keras.Model(inputs=x_input, outputs=[out])
print(self.model.summary())
class BranchedInputRNN(RNN): # pragma: no cover class BranchedInputRNN(RNN): # pragma: no cover
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment