From 19e758cc206eae3956e3cdb2e9c86cf219d4d4cb Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Thu, 27 May 2021 11:47:52 +0200
Subject: [PATCH] split activation for RNN and dense layer parts

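The RNN model class previously applied the single activation argument
both after each recurrent layer and after the optional dense layer.
Introduce a separate activation_rnn argument (default "tanh") for the
recurrent part, so that activation only controls the dense part. Also
re-enable the kernel initializer and pass it to the optional dense
layer.

A minimal usage sketch (hypothetical shape values; the exact
input_shape/output_shape conventions come from the surrounding MLAir
code and are not part of this patch):

    # separate activations for the recurrent and dense parts
    model = RNN(input_shape=[(21, 5)], output_shape=[4],
                activation="relu",      # dense layer part
                activation_rnn="tanh",  # recurrent layer part
                rnn_type="gru", add_dense_layer=True)
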
---
 mlair/model_modules/recurrent_networks.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/mlair/model_modules/recurrent_networks.py b/mlair/model_modules/recurrent_networks.py
index 0861d416..cbe5d145 100644
--- a/mlair/model_modules/recurrent_networks.py
+++ b/mlair/model_modules/recurrent_networks.py
@@ -31,6 +31,7 @@ class RNN(AbstractModelClass):
     _rnn = {"lstm": keras.layers.LSTM, "gru": keras.layers.GRU}
 
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
+                 activation_rnn="tanh",
                  optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
                  batch_normalization=False, rnn_type="lstm", add_dense_layer=False, **kwargs):
         """
@@ -68,6 +69,8 @@ class RNN(AbstractModelClass):
         # settings
         self.activation = self._set_activation(activation.lower())
         self.activation_name = activation
+        self.activation_rnn = self._set_activation(activation_rnn.lower())
+        self.activation_rnn_name = activation_rnn
         self.activation_output = self._set_activation(activation_output.lower())
         self.activation_output_name = activation_output
         self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs)
@@ -76,7 +79,7 @@ class RNN(AbstractModelClass):
         self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
         self.RNN = self._rnn.get(rnn_type.lower())
         self._update_model_name(rnn_type)
-        # self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
+        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
         # self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
         self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
 
@@ -105,12 +108,13 @@ class RNN(AbstractModelClass):
             x_in = self.RNN(n_hidden, return_sequences=return_sequences)(x_in)
             if self.bn is True:
                 x_in = keras.layers.BatchNormalization()(x_in)
-            x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
+            x_in = self.activation_rnn(name=f"{self.activation_rnn_name}_{layer + 1}")(x_in)
             if self.dropout is not None:
                 x_in = self.dropout(self.dropout_rate)(x_in)
 
         if self.add_dense_layer is True:
-            x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}")(x_in)
+            x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}",
+                                      kernel_initializer=self.kernel_initializer)(x_in)
             x_in = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_in)
         x_in = keras.layers.Dense(self._output_shape)(x_in)
         out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
@@ -172,7 +176,6 @@ class RNN(AbstractModelClass):
     #         return reg(**reg_kwargs)
     #     except KeyError:
     #         raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
-    #
 
     def _update_model_name(self, rnn_type):
         n_input = str(reduce(lambda x, y: x * y, self._input_shape))
-- 
GitLab