Commit 19e758cc authored by leufen1

split activation for rnn and dense layer part

parent 97ca2fa3
5 merge requests: !319 add all changes of dev into release v1.4.0 branch, !318 Resolve "release v1.4.0", !317 enabled window_lead_time=1, !295 Resolve "data handler FIR filter", !259 Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #68653 passed
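In effect, the recurrent layers are now activated via the new activation_rnn argument (default "tanh"), while the optional dense layer and the remaining activations keep using activation. A minimal, hypothetical instantiation sketch for illustration only; the import path, shape values, and keyword choices are assumptions and not part of this commit:

    from mlair.model_modules.recurrent_networks import RNN  # import path assumed

    # recurrent part activated with tanh, dense tail with relu, output stays linear
    model = RNN(input_shape=[(65, 1, 9)], output_shape=[(4,)],
                activation="relu", activation_rnn="tanh", activation_output="linear",
                rnn_type="lstm", n_layer=2, n_hidden=32,
                add_dense_layer=True, dropout=0.2)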
@@ -31,6 +31,7 @@ class RNN(AbstractModelClass):
     _rnn = {"lstm": keras.layers.LSTM, "gru": keras.layers.GRU}
 
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
+                 activation_rnn="tanh",
                  optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
                  batch_normalization=False, rnn_type="lstm", add_dense_layer=False, **kwargs):
         """
@@ -68,6 +69,8 @@ class RNN(AbstractModelClass):
         # settings
         self.activation = self._set_activation(activation.lower())
         self.activation_name = activation
+        self.activation_rnn = self._set_activation(activation_rnn.lower())
+        self.activation_rnn_name = activation_rnn
         self.activation_output = self._set_activation(activation_output.lower())
         self.activation_output_name = activation_output
         self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs)
@@ -76,7 +79,7 @@
         self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
         self.RNN = self._rnn.get(rnn_type.lower())
         self._update_model_name(rnn_type)
-        # self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
+        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
         # self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
         self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
@@ -105,12 +108,13 @@
             x_in = self.RNN(n_hidden, return_sequences=return_sequences)(x_in)
             if self.bn is True:
                 x_in = keras.layers.BatchNormalization()(x_in)
-            x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in)
+            x_in = self.activation_rnn(name=f"{self.activation_rnn_name}_{layer + 1}")(x_in)
             if self.dropout is not None:
                 x_in = self.dropout(self.dropout_rate)(x_in)
         if self.add_dense_layer is True:
-            x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}")(x_in)
+            x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}",
+                                      kernel_initializer=self.kernel_initializer, )(x_in)
             x_in = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_in)
         x_in = keras.layers.Dense(self._output_shape)(x_in)
         out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
@@ -172,7 +176,6 @@ class RNN(AbstractModelClass):
     #         return reg(**reg_kwargs)
     #     except KeyError:
     #         raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
-    #
     def _update_model_name(self, rnn_type):
         n_input = str(reduce(lambda x, y: x * y, self._input_shape))