From 2a9a5d84b7e904a27b37149c90bb2494e4fd18a5 Mon Sep 17 00:00:00 2001 From: leufen1 <l.leufen@fz-juelich.de> Date: Tue, 22 Feb 2022 09:20:29 +0100 Subject: [PATCH] corrected input variable --- mlair/model_modules/branched_input_networks.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mlair/model_modules/branched_input_networks.py b/mlair/model_modules/branched_input_networks.py index 51bef4f9..4345d48b 100644 --- a/mlair/model_modules/branched_input_networks.py +++ b/mlair/model_modules/branched_input_networks.py @@ -8,7 +8,7 @@ from mlair.model_modules.loss import var_loss from mlair.model_modules.recurrent_networks import RNN -class BranchInputRNN(RNN): # pragma: no cover +class BranchedInputRNN(RNN): # pragma: no cover """A recurrent neural network with multiple input branches.""" def __init__(self, input_shape, output_shape, *args, **kwargs): @@ -37,7 +37,9 @@ class BranchInputRNN(RNN): # pragma: no cover for branch in range(len(self._input_shape)): shape_b = self._input_shape[branch] x_input_b = keras.layers.Input(shape=shape_b) - x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])))(x_input_b) + x_input.append(x_input_b) + x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])), + name=f"reshape_branch{branch + 1}")(x_input_b) for layer, n_hidden in enumerate(conf): return_sequences = (layer < len(conf) - 1) @@ -69,7 +71,7 @@ class BranchInputRNN(RNN): # pragma: no cover x_concat = self.dropout(self.dropout_rate)(x_concat) x_concat = keras.layers.Dense(self._output_shape)(x_concat) - out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in) + out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat) self.model = keras.Model(inputs=x_input, outputs=[out]) print(self.model.summary()) @@ -78,8 +80,8 @@ class BranchInputRNN(RNN): # pragma: no cover "metrics": ["mse", "mae", var_loss]} def _update_model_name(self, rnn_type): - # n_input = str(reduce(lambda x, y: x * y, self._input_shape)) - n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}" + n_input = f"{len(self._input_shape)}x{self._input_shape[0][0]}x" \ + f"{str(reduce(lambda x, y: x * y, self._input_shape[0][1:]))}" n_output = str(self._output_shape) self.model_name = rnn_type.upper() if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2: -- GitLab