Skip to content
Snippets Groups Projects
Commit 3a94611e authored by leufen1's avatar leufen1
Browse files

add dense layer between rnn and output

parent 8da0c387
No related branches found
No related tags found
5 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!317enabled window_lead_time=1,!295Resolve "data handler FIR filter",!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #68647 passed
...@@ -32,7 +32,7 @@ class RNN(AbstractModelClass): ...@@ -32,7 +32,7 @@ class RNN(AbstractModelClass):
def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear", def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None, optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
batch_normalization=False, rnn_type="lstm", **kwargs): batch_normalization=False, rnn_type="lstm", add_dense_layer=False, **kwargs):
""" """
Sets model and loss depending on the given arguments. Sets model and loss depending on the given arguments.
...@@ -72,6 +72,7 @@ class RNN(AbstractModelClass): ...@@ -72,6 +72,7 @@ class RNN(AbstractModelClass):
self.activation_output_name = activation_output self.activation_output_name = activation_output
self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs) self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs)
self.bn = batch_normalization self.bn = batch_normalization
self.add_dense_layer = add_dense_layer
self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
self.RNN = self._rnn.get(rnn_type.lower()) self.RNN = self._rnn.get(rnn_type.lower())
self._update_model_name(rnn_type) self._update_model_name(rnn_type)
...@@ -108,6 +109,9 @@ class RNN(AbstractModelClass): ...@@ -108,6 +109,9 @@ class RNN(AbstractModelClass):
if self.dropout is not None: if self.dropout is not None:
x_in = self.dropout(self.dropout_rate)(x_in) x_in = self.dropout(self.dropout_rate)(x_in)
if self.add_dense_layer is True:
x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), ame=f"Dense_{len(conf) + 1}")(x_in)
x_in = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_in)
x_in = keras.layers.Dense(self._output_shape)(x_in) x_in = keras.layers.Dense(self._output_shape)(x_in)
out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in) out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
self.model = keras.Model(inputs=x_input, outputs=[out]) self.model = keras.Model(inputs=x_input, outputs=[out])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment