Skip to content
Snippets Groups Projects
Commit 2a9a5d84 authored by leufen1's avatar leufen1
Browse files

corrected input variable

parent 2015c09d
No related branches found
No related tags found
6 merge requests!430update recent developments,!413update release branch,!412Resolve "release v2.0.0",!395Lukas issue362 feat branched rnn,!390Lukas issue362 feat branched rnn,!388Resolve "branched rnn model class"
Pipeline #92738 passed
......@@ -8,7 +8,7 @@ from mlair.model_modules.loss import var_loss
from mlair.model_modules.recurrent_networks import RNN
class BranchInputRNN(RNN): # pragma: no cover
class BranchedInputRNN(RNN): # pragma: no cover
"""A recurrent neural network with multiple input branches."""
def __init__(self, input_shape, output_shape, *args, **kwargs):
......@@ -37,7 +37,9 @@ class BranchInputRNN(RNN): # pragma: no cover
for branch in range(len(self._input_shape)):
shape_b = self._input_shape[branch]
x_input_b = keras.layers.Input(shape=shape_b)
x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])))(x_input_b)
x_input.append(x_input_b)
x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])),
name=f"reshape_branch{branch + 1}")(x_input_b)
for layer, n_hidden in enumerate(conf):
return_sequences = (layer < len(conf) - 1)
......@@ -69,7 +71,7 @@ class BranchInputRNN(RNN): # pragma: no cover
x_concat = self.dropout(self.dropout_rate)(x_concat)
x_concat = keras.layers.Dense(self._output_shape)(x_concat)
out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
self.model = keras.Model(inputs=x_input, outputs=[out])
print(self.model.summary())
......@@ -78,8 +80,8 @@ class BranchInputRNN(RNN): # pragma: no cover
"metrics": ["mse", "mae", var_loss]}
def _update_model_name(self, rnn_type):
# n_input = str(reduce(lambda x, y: x * y, self._input_shape))
n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}"
n_input = f"{len(self._input_shape)}x{self._input_shape[0][0]}x" \
f"{str(reduce(lambda x, y: x * y, self._input_shape[0][1:]))}"
n_output = str(self._output_shape)
self.model_name = rnn_type.upper()
if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment