Commit 2a9a5d84 authored by leufen1

corrected input variable

parent 2015c09d
6 merge requests: !430 "update recent developments", !413 "update release branch", !412 Resolve "release v2.0.0", !395 "Lukas issue362 feat branched rnn", !390 "Lukas issue362 feat branched rnn", !388 Resolve "branched rnn model class"
Pipeline #92738 passed
@@ -8,7 +8,7 @@ from mlair.model_modules.loss import var_loss
 from mlair.model_modules.recurrent_networks import RNN
 
 
-class BranchInputRNN(RNN):  # pragma: no cover
+class BranchedInputRNN(RNN):  # pragma: no cover
     """A recurrent neural network with multiple input branches."""
 
     def __init__(self, input_shape, output_shape, *args, **kwargs):
@@ -37,7 +37,9 @@ class BranchInputRNN(RNN):  # pragma: no cover
         for branch in range(len(self._input_shape)):
             shape_b = self._input_shape[branch]
             x_input_b = keras.layers.Input(shape=shape_b)
-            x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])))(x_input_b)
+            x_input.append(x_input_b)
+            x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])),
+                                          name=f"reshape_branch{branch + 1}")(x_input_b)
             for layer, n_hidden in enumerate(conf):
                 return_sequences = (layer < len(conf) - 1)
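This hunk is the fix named in the commit message: each branch's Input tensor is now appended to x_input, the list that the end of the method hands to keras.Model(inputs=x_input, ...); before the change the loop created x_input_b but never collected it. A minimal standalone sketch of the corrected pattern (the shapes, the LSTM branch body, and the layer sizes are illustrative assumptions, not MLAir's actual configuration):

    from functools import reduce
    from tensorflow import keras

    input_shapes = [(7, 2, 5), (7, 1, 3)]  # hypothetical per-branch shapes
    x_input, x_branches = [], []
    for branch, shape_b in enumerate(input_shapes):
        x_input_b = keras.layers.Input(shape=shape_b)
        x_input.append(x_input_b)  # collect entry tensors for keras.Model below
        # flatten everything after the time dimension, as in the diff above
        x_in_b = keras.layers.Reshape((shape_b[0], reduce(lambda x, y: x * y, shape_b[1:])),
                                      name=f"reshape_branch{branch + 1}")(x_input_b)
        x_branches.append(keras.layers.LSTM(16)(x_in_b))
    x_concat = keras.layers.Concatenate()(x_branches)
    out = keras.layers.Dense(1)(x_concat)
    model = keras.Model(inputs=x_input, outputs=[out])  # one entry tensor per branch
    model.summary()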
@@ -69,7 +71,7 @@ class BranchInputRNN(RNN):  # pragma: no cover
         x_concat = self.dropout(self.dropout_rate)(x_concat)
         x_concat = keras.layers.Dense(self._output_shape)(x_concat)
-        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
+        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
 
         self.model = keras.Model(inputs=x_input, outputs=[out])
         print(self.model.summary())
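The output activation now consumes x_concat, the concatenated and fully connected tensor, instead of x_in, which at that point presumably still referred to a single branch, a leftover from the parent RNN implementation.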
@@ -78,8 +80,8 @@ class BranchInputRNN(RNN):  # pragma: no cover
                 "metrics": ["mse", "mae", var_loss]}
 
     def _update_model_name(self, rnn_type):
-        # n_input = str(reduce(lambda x, y: x * y, self._input_shape))
-        n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}"
+        n_input = f"{len(self._input_shape)}x{self._input_shape[0][0]}x" \
+                  f"{str(reduce(lambda x, y: x * y, self._input_shape[0][1:]))}"
         n_output = str(self._output_shape)
         self.model_name = rnn_type.upper()
         if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
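The rebuilt model name now encodes the branch count, the time dimension, and the flattened remaining dimensions separately, instead of multiplying the whole first branch shape together. A worked example with assumed shapes:

    from functools import reduce

    _input_shape = [(7, 2, 5), (7, 2, 5)]  # hypothetical: two branches
    # old: f"{len(_input_shape)}x{reduce(lambda x, y: x * y, _input_shape[0])}"  -> "2x70"
    n_input = f"{len(_input_shape)}x{_input_shape[0][0]}x" \
              f"{reduce(lambda x, y: x * y, _input_shape[0][1:])}"
    print(n_input)  # -> "2x7x10"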