def _set_regularizer(self, regularizer: Union[None, str], **kwargs):
    """
    Create the regularizer selected by name and report the options it was created with.

    The name is lower-cased and looked up in ``self._regularizer``. For ``"l1"`` and ``"l2"`` the equally named
    entry of ``kwargs`` is forwarded to the regularizer under the keyword ``l``; for ``"l1_l2"`` the entries
    ``l1`` and ``l2`` are forwarded under their own names. Any other supported name is called without arguments.

    :param regularizer: name of the regularizer (case-insensitive); ``None`` or the string ``"none"`` disables
        regularization
    :param kwargs: additional options, only ``l1`` and ``l2`` are read here (selected via ``select_from_dict``)
    :return: tuple ``(regularizer instance, dict of keyword arguments used)`` or ``(None, None)`` if
        regularization is disabled
    :raises AttributeError: if the given name is not available in ``self._regularizer``
    """
    if regularizer is None or (isinstance(regularizer, str) and regularizer.lower() == "none"):
        return None, None
    try:
        reg_name = regularizer.lower()
        reg = self._regularizer.get(reg_name)
        if reg is None:
            # .get() does not raise KeyError on a missing entry, so an unknown name has to be rejected
            # explicitly; otherwise calling ``None`` below would fail with an unrelated TypeError.
            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
        reg_kwargs = {}
        if reg_name in ["l1", "l2"]:
            reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
            if reg_name in reg_kwargs:
                # single-norm regularizers are called with the generic keyword ``l``
                reg_kwargs["l"] = reg_kwargs.pop(reg_name)
        elif reg_name == "l1_l2":
            reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
        return reg(**reg_kwargs), reg_kwargs
    except KeyError:
        # kept for backward compatibility: a KeyError raised anywhere above is still reported as AttributeError
        raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")