diff --git a/mlair/model_modules/branched_input_networks.py b/mlair/model_modules/branched_input_networks.py
index 8da3b17edc47fe281738bac54d26a28836716fac..2c62c3cafc1537979e4a21bdb3bb6aa798e6e193 100644
--- a/mlair/model_modules/branched_input_networks.py
+++ b/mlair/model_modules/branched_input_networks.py
@@ -44,7 +44,8 @@ class BranchedInputRNN(RNN):  # pragma: no cover
             for layer, n_hidden in enumerate(conf):
                 return_sequences = (layer < len(conf) - 1)
                 x_in_b = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn,
-                                  name=f"{self.RNN.__name__}_branch{branch + 1}_{layer + 1}")(x_in_b)
+                                  name=f"{self.RNN.__name__}_branch{branch + 1}_{layer + 1}",
+                                  kernel_regularizer=self.kernel_regularizer)(x_in_b)
                 if self.bn is True:
                     x_in_b = keras.layers.BatchNormalization()(x_in_b)
                 x_in_b = self.activation_rnn(name=f"{self.activation_rnn_name}_branch{branch + 1}_{layer + 1}")(x_in_b)
diff --git a/mlair/model_modules/recurrent_networks.py b/mlair/model_modules/recurrent_networks.py
index e909ae7696bdf90d4e9a95e020b75a97e15dfd50..13e6fbecc7f3936a788dd6b035b9a7abe7b42857 100644
--- a/mlair/model_modules/recurrent_networks.py
+++ b/mlair/model_modules/recurrent_networks.py
@@ -2,6 +2,7 @@ __author__ = "Lukas Leufen"
 __date__ = '2021-05-25'
 
 from functools import reduce, partial
+from typing import Union
 
 from mlair.model_modules import AbstractModelClass
 from mlair.helpers import select_from_dict
@@ -33,7 +34,8 @@ class RNN(AbstractModelClass):  # pragma: no cover
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
                  activation_rnn="tanh", dropout_rnn=0, optimizer="adam", n_layer=1, n_hidden=10, regularizer=None,
                  dropout=None, layer_configuration=None,
-                 batch_normalization=False, rnn_type="lstm", add_dense_layer=False, dense_layer_configuration=None, **kwargs):
+                 batch_normalization=False, rnn_type="lstm", add_dense_layer=False, dense_layer_configuration=None,
+                 kernel_regularizer=None, **kwargs):
         """
         Sets model and loss depending on the given arguments.
 
@@ -42,10 +44,12 @@ class RNN(AbstractModelClass):  # pragma: no cover
 
         Customize this RNN model via the following parameters:
 
-        :param activation: set your desired activation function for appended dense layers (add_dense_layer=True=. Choose
+        :param activation: set your desired activation function for appended dense layers (add_dense_layer=True). Choose
             from relu, tanh, sigmoid, linear, selu, prelu, leakyrelu. (Default relu)
         :param activation_rnn: set your desired activation function of the rnn output. Choose from relu, tanh, sigmoid,
-            linear, selu, prelu, leakyrelu. (Default tanh)
+            linear, selu, prelu, leakyrelu. To use the fast cuDNN implementation, tensorflow requires to use tanh as
+            activation. Note that this is not the recurrent activation (which is not mutable in this class) but the
+            activation of the cell. (Default tanh)
         :param activation_output: same as activation parameter but exclusively applied on output layer only. (Default
             linear)
         :param optimizer: set optimizer method. Can be either adam or sgd. (Default adam)
@@ -58,7 +62,8 @@ class RNN(AbstractModelClass):  # pragma: no cover
         :param dropout: use dropout with given rate. If no value is provided, dropout layers are not added to the
             network at all. (Default None)
         :param dropout_rnn: use recurrent dropout with given rate. This is applied along the recursion and not after
-            a rnn layer. (Default 0)
+            a rnn layer. Be aware that tensorflow is only able to use the fast cuDNN implementation with no recurrent
+            dropout. (Default 0)
         :param batch_normalization: use batch normalization layer in the network if enabled. These layers are inserted
             between the linear part of a layer (the nn part) and the non-linear part (activation function). No BN layer
             is added if set to false. (Default false)
@@ -94,7 +99,7 @@ class RNN(AbstractModelClass):  # pragma: no cover
         self.RNN = self._rnn.get(rnn_type.lower())
         self._update_model_name(rnn_type)
         self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
-        # self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
+        self.kernel_regularizer, self.kernel_regularizer_opts = self._set_regularizer(kernel_regularizer, **kwargs)
         self.dropout, self.dropout_rate = self._set_dropout(activation, dropout)
         assert 0 <= dropout_rnn <= 1
         self.dropout_rnn = dropout_rnn
@@ -121,7 +126,8 @@ class RNN(AbstractModelClass):  # pragma: no cover
 
         for layer, n_hidden in enumerate(conf):
             return_sequences = (layer < len(conf) - 1)
-            x_in = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn)(x_in)
+            x_in = self.RNN(n_hidden, return_sequences=return_sequences, recurrent_dropout=self.dropout_rnn,
+                            kernel_regularizer=self.kernel_regularizer)(x_in)
             if self.bn is True:
                 x_in = keras.layers.BatchNormalization()(x_in)
             x_in = self.activation_rnn(name=f"{self.activation_rnn_name}_{layer + 1}")(x_in)
@@ -188,23 +194,23 @@ class RNN(AbstractModelClass):  # pragma: no cover
                 return opt(**opt_kwargs)
             except KeyError:
                 raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
-    #
-    # def _set_regularizer(self, regularizer, **kwargs):
-    #     if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
-    #         return None
-    #     try:
-    #         reg_name = regularizer.lower()
-    #         reg = self._regularizer.get(reg_name)
-    #         reg_kwargs = {}
-    #         if reg_name in ["l1", "l2"]:
-    #             reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
-    #             if reg_name in reg_kwargs:
-    #                 reg_kwargs["l"] = reg_kwargs.pop(reg_name)
-    #         elif reg_name == "l1_l2":
-    #             reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
-    #         return reg(**reg_kwargs)
-    #     except KeyError:
-    #         raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
+
+    def _set_regularizer(self, regularizer: Union[None, str], **kwargs):
+        if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
+            return None, None
+        try:
+            reg_name = regularizer.lower()
+            reg = self._regularizer.get(reg_name)
+            reg_kwargs = {}
+            if reg_name in ["l1", "l2"]:
+                reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
+                if reg_name in reg_kwargs:
+                    reg_kwargs["l"] = reg_kwargs.pop(reg_name)
+            elif reg_name == "l1_l2":
+                reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
+            return reg(**reg_kwargs), reg_kwargs
+        except KeyError:
+            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
 
     def _update_model_name(self, rnn_type):
         n_input = str(reduce(lambda x, y: x * y, self._input_shape))
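
Usage note (not part of the diff): a minimal sketch of how the new kernel_regularizer argument could be passed to the RNN class, based on the _set_regularizer logic above. The shape values, the l2 weight, and the list-of-shapes convention for input_shape/output_shape are illustrative assumptions, not taken from this change.

    from mlair.model_modules.recurrent_networks import RNN

    # Illustrative values only; real shapes come from the surrounding mlair setup.
    model = RNN(
        input_shape=[(65, 1, 9)],   # assumed placeholder shape
        output_shape=[4],           # assumed placeholder forecast horizon
        rnn_type="lstm",
        n_layer=2,
        n_hidden=32,
        kernel_regularizer="l2",    # resolved by _set_regularizer via self._regularizer
        l2=0.01,                    # picked up from **kwargs by select_from_dict and renamed to "l"
        dropout_rnn=0,              # keep 0 so tensorflow can use the fast cuDNN implementation
        activation_rnn="tanh",      # tanh is required for the cuDNN path (see docstring above)
    )

The resolved regularizer is then handed to every recurrent layer as kernel_regularizer, in both the plain RNN and the branched variant.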