Commit 3a94611e authored by leufen1

add dense layer between rnn and output

parent 8da0c387
Pipeline #68647 passed
@@ -32,7 +32,7 @@ class RNN(AbstractModelClass):
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
                  optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
-                 batch_normalization=False, rnn_type="lstm", **kwargs):
+                 batch_normalization=False, rnn_type="lstm", add_dense_layer=False, **kwargs):
         """
         Sets model and loss depending on the given arguments.
@@ -72,6 +72,7 @@
         self.activation_output_name = activation_output
         self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs)
         self.bn = batch_normalization
+        self.add_dense_layer = add_dense_layer
         self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
         self.RNN = self._rnn.get(rnn_type.lower())
         self._update_model_name(rnn_type)
@@ -108,6 +109,9 @@
         if self.dropout is not None:
             x_in = self.dropout(self.dropout_rate)(x_in)
+        if self.add_dense_layer is True:
+            x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}")(x_in)
+            x_in = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_in)
         x_in = keras.layers.Dense(self._output_shape)(x_in)
         out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out])
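
The new flag inserts one Dense layer plus activation between the last recurrent layer and the output layer; its width is min(output size squared, conf[-1]), i.e. capped at the width of the last configured layer. A minimal usage sketch follows; the import path and all shape values are assumptions for illustration, not taken from this commit:

    # Hypothetical usage of the new flag; module path and shape values are
    # assumed for illustration, not part of this commit.
    from mlair.model_modules.recurrent_networks import RNN

    # Placeholder shapes: 13 time steps with 2 features in, 3 lead times out.
    model = RNN(input_shape=[(13, 2)], output_shape=[(3,)],
                rnn_type="lstm", n_layer=2, n_hidden=32, add_dense_layer=True)

    # Inspect the resulting network (assuming the model is built on
    # construction; self.model is a keras.Model per the diff above).
    model.model.summary()

With these placeholder values, the inserted layer would have min(3 ** 2, 32) = 9 units and use the model's configured activation (default relu) before the final Dense output layer.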