Skip to content
Snippets Groups Projects
Commit 19e758cc authored by leufen1's avatar leufen1
Browse files

split activation for rnn and dense layer part

parent 97ca2fa3
Branches
Tags
5 merge requests!319add all changes of dev into release v1.4.0 branch,!318Resolve "release v1.4.0",!317enabled window_lead_time=1,!295Resolve "data handler FIR filter",!259Draft: Resolve "WRF-Datahandler should inherit from SingleStationDatahandler"
Pipeline #68653 passed
...@@ -31,6 +31,7 @@ class RNN(AbstractModelClass): ...@@ -31,6 +31,7 @@ class RNN(AbstractModelClass):
_rnn = {"lstm": keras.layers.LSTM, "gru": keras.layers.GRU} _rnn = {"lstm": keras.layers.LSTM, "gru": keras.layers.GRU}
def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear", def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
activation_rnn="tanh",
optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None, optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, layer_configuration=None,
batch_normalization=False, rnn_type="lstm", add_dense_layer=False, **kwargs): batch_normalization=False, rnn_type="lstm", add_dense_layer=False, **kwargs):
""" """
...@@ -68,6 +69,8 @@ class RNN(AbstractModelClass): ...@@ -68,6 +69,8 @@ class RNN(AbstractModelClass):
# settings # settings
self.activation = self._set_activation(activation.lower()) self.activation = self._set_activation(activation.lower())
self.activation_name = activation self.activation_name = activation
self.activation_rnn = self._set_activation(activation_rnn.lower())
self.activation_rnn_name = activation
self.activation_output = self._set_activation(activation_output.lower()) self.activation_output = self._set_activation(activation_output.lower())
self.activation_output_name = activation_output self.activation_output_name = activation_output
self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs) self.optimizer = self._set_optimizer(optimizer.lower(), **kwargs)
...@@ -76,7 +79,7 @@ class RNN(AbstractModelClass): ...@@ -76,7 +79,7 @@ class RNN(AbstractModelClass):
self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration self.layer_configuration = (n_layer, n_hidden) if layer_configuration is None else layer_configuration
self.RNN = self._rnn.get(rnn_type.lower()) self.RNN = self._rnn.get(rnn_type.lower())
self._update_model_name(rnn_type) self._update_model_name(rnn_type)
# self.kernel_initializer = self._initializer.get(activation, "glorot_uniform") self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
# self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs) # self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
self.dropout, self.dropout_rate = self._set_dropout(activation, dropout) self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
...@@ -105,12 +108,13 @@ class RNN(AbstractModelClass): ...@@ -105,12 +108,13 @@ class RNN(AbstractModelClass):
x_in = self.RNN(n_hidden, return_sequences=return_sequences)(x_in) x_in = self.RNN(n_hidden, return_sequences=return_sequences)(x_in)
if self.bn is True: if self.bn is True:
x_in = keras.layers.BatchNormalization()(x_in) x_in = keras.layers.BatchNormalization()(x_in)
x_in = self.activation(name=f"{self.activation_name}_{layer + 1}")(x_in) x_in = self.activation_rnn(name=f"{self.activation_rnn_name}_{layer + 1}")(x_in)
if self.dropout is not None: if self.dropout is not None:
x_in = self.dropout(self.dropout_rate)(x_in) x_in = self.dropout(self.dropout_rate)(x_in)
if self.add_dense_layer is True: if self.add_dense_layer is True:
x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}")(x_in) x_in = keras.layers.Dense(min(self._output_shape ** 2, conf[-1]), name=f"Dense_{len(conf) + 1}",
kernel_initializer=self.kernel_initializer, )(x_in)
x_in = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_in) x_in = self.activation(name=f"{self.activation_name}_{len(conf) + 1}")(x_in)
x_in = keras.layers.Dense(self._output_shape)(x_in) x_in = keras.layers.Dense(self._output_shape)(x_in)
out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in) out = self.activation_output(name=f"{self.activation_output_name}_output")(x_in)
...@@ -172,7 +176,6 @@ class RNN(AbstractModelClass): ...@@ -172,7 +176,6 @@ class RNN(AbstractModelClass):
# return reg(**reg_kwargs) # return reg(**reg_kwargs)
# except KeyError: # except KeyError:
# raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.") # raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
#
def _update_model_name(self, rnn_type): def _update_model_name(self, rnn_type):
n_input = str(reduce(lambda x, y: x * y, self._input_shape)) n_input = str(reduce(lambda x, y: x * y, self._input_shape))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment