From 2a9a5d84b7e904a27b37149c90bb2494e4fd18a5 Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Tue, 22 Feb 2022 09:20:29 +0100
Subject: [PATCH] corrected input variable

---
 mlair/model_modules/branched_input_networks.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlair/model_modules/branched_input_networks.py b/mlair/model_modules/branched_input_networks.py
index 51bef4f9..4345d48b 100644
--- a/mlair/model_modules/branched_input_networks.py
+++ b/mlair/model_modules/branched_input_networks.py
@@ -8,7 +8,7 @@ from mlair.model_modules.loss import var_loss
 from mlair.model_modules.recurrent_networks import RNN
 
 
-class BranchInputRNN(RNN):  # pragma: no cover
+class BranchedInputRNN(RNN):  # pragma: no cover
     """A recurrent neural network with multiple input branches."""
 
     def __init__(self, input_shape, output_shape, *args, **kwargs):
@@ -37,7 +37,9 @@ class BranchInputRNN(RNN):  # pragma: no cover
         for branch in range(len(self._input_shape)):
             shape_b = self._input_shape[branch]
             x_input_b = keras.layers.Input(shape=shape_b)
-            x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])))(x_input_b)
+            x_input.append(x_input_b)
+            x_in_b = keras.layers.Reshape((shape_b[0], reduce((lambda x, y: x * y), shape_b[1:])),
+                                          name=f"reshape_branch{branch + 1}")(x_input_b)
 
             for layer, n_hidden in enumerate(conf):
                 return_sequences = (layer < len(conf) - 1)
@@ -69,7 +71,7 @@ class BranchInputRNN(RNN):  # pragma: no cover
                         x_concat = self.dropout(self.dropout_rate)(x_concat)
 
         x_concat = keras.layers.Dense(self._output_shape)(x_concat)
-        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
+        out = self.activation_output(name=f"{self.activation_output_name}_output")(x_concat)
         self.model = keras.Model(inputs=x_input, outputs=[out])
         print(self.model.summary())
 
@@ -78,8 +80,8 @@ class BranchInputRNN(RNN):  # pragma: no cover
                                 "metrics": ["mse", "mae", var_loss]}
 
     def _update_model_name(self, rnn_type):
-        # n_input = str(reduce(lambda x, y: x * y, self._input_shape))
-        n_input = f"{len(self._input_shape)}x{str(reduce(lambda x, y: x * y, self._input_shape[0]))}"
+        n_input = f"{len(self._input_shape)}x{self._input_shape[0][0]}x" \
+                  f"{str(reduce(lambda x, y: x * y, self._input_shape[0][1:]))}"
         n_output = str(self._output_shape)
         self.model_name = rnn_type.upper()
         if isinstance(self.layer_configuration, tuple) and len(self.layer_configuration) == 2:
-- 
GitLab