From 3a94611ecb5c13f0afd93baafdd66bd1c206a0fe Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 27 May 2021 11:16:27 +0200
Subject: [PATCH] add dense layer between rnn and output

---
 mlair/model_modules/recurrent_networks.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlair/model_modules/recurrent_networks.py b/mlair/model_modules/recurrent_networks.py
index 7adc9111..ab28085b 100644
--- a/mlair/model_modules/recurrent_networks.py
+++ b/mlair/model_modules/recurrent_networks.py
@@ -32,7 +32,7 @@ class RNN(AbstractModelClass):
 
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
                  optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
-                 batch_normalization=False, rnn_type="lstm", **kwargs):
+                 batch_normalization=False, rnn_type="lstm", add_dense_layer=False, **kwargs):
         """
         Sets model and loss depending on the given arguments.
 
@@ -72,6 +72,7 @@ class RNN(AbstractModelClass):
         self.activation_output_name = activation_output
         self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs)
         self.bn = batch_normalization
+        self.add_dense_layer = add_dense_layer
         self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
         self.RNN = self._rnn.get(rnn_type.lower())
         self._update_model_name(rnn_type)
@@ -108,6 +109,9 @@ class RNN(AbstractModelClass):
         if self.dropout is not None:
             x_in = self.dropout(self.dropout_rate)(x_in)
 
+        if self.add_dense_layer is True:
+            x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}")(x_in)
+            x_in = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_in)
         x_in = keras.layers.Dense(self._output_shape)(x_in)
         out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out])
--
GitLab
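
For reference, a minimal usage sketch of the new flag (not part of the patch). It assumes the RNN class is importable from mlair.model_modules.recurrent_networks as in the diff, that the model is built on instantiation as in other mlair model classes, and that the shape values below are purely illustrative placeholders rather than values from the repository.

# Illustrative sketch only: the import path is taken from the diff, but the
# shape values and keyword choices are assumptions made for demonstration.
from mlair.model_modules.recurrent_networks import RNN

# With add_dense_layer=True, an extra Dense layer (followed by the configured
# activation) is inserted between the last recurrent layer and the output layer.
model_cls = RNN(input_shape=[(65, 1, 9)], output_shape=[(4,)],
                rnn_type="lstm", n_layer=2, n_hidden=32,
                add_dense_layer=True)

# The extra layer should show up in the summary under the name pattern
# "Dense_<n>" used in the patch.
model_cls.model.summary()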