diff --git a/mlair/helpers/helpers.py b/mlair/helpers/helpers.py
index ee727ef59ff35334be0a52a4d78dbae814d6c205..b57b733b08c4635a16d7fd18e99538a991521fd8 100644
--- a/mlair/helpers/helpers.py
+++ b/mlair/helpers/helpers.py
@@ -103,7 +103,7 @@ def remove_items(obj: Union[List, Dict], items: Any):
         raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.")
 
 
-def select_from_dict(dict_obj: dict, sel_list: Any):
+def select_from_dict(dict_obj: dict, sel_list: Any, remove_none=False):
     """
     Extract all key values pairs whose key is contained in the sel_list.
 
@@ -113,6 +113,7 @@ def select_from_dict(dict_obj: dict, sel_list: Any):
     sel_list = to_list(sel_list)
     assert isinstance(dict_obj, dict)
     sel_dict = {k: v for k, v in dict_obj.items() if k in sel_list}
+    sel_dict = sel_dict if not remove_none else {k: v for k, v in sel_dict.items() if v is not None}
     return sel_dict
 
 
diff --git a/mlair/model_modules/fully_connected_networks.py b/mlair/model_modules/fully_connected_networks.py
index 9e3657c36462255cacbb12b2720eb8243e37bc92..948d2b06d7bf801c585a4e7193a68cf75ada9e8a 100644
--- a/mlair/model_modules/fully_connected_networks.py
+++ b/mlair/model_modules/fully_connected_networks.py
@@ -69,10 +69,11 @@ class FCN(AbstractModelClass):
                    "selu": partial(keras.layers.Activation, "selu")}
     _initializer = {"selu": keras.initializers.lecun_normal()}
     _optimizer = {"adam": keras.optimizers.adam, "sgd": keras.optimizers.SGD}
-    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov"]
+    _regularizer = {"l1": keras.regularizers.l1, "l2": keras.regularizers.l2, "l1_l2": keras.regularizers.l1_l2}
+    _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov", "l1", "l2"]
 
     def __init__(self, input_shape: list, output_shape: list, activation="relu", activation_output="linear",
-                 optimizer="adam", n_layer=1, n_hidden=10, **kwargs):
+                 optimizer="adam", n_layer=1, n_hidden=10, regularizer=None, dropout=None, **kwargs):
         """
         Sets model and loss depending on the given arguments.
 
@@ -91,6 +92,8 @@ class FCN(AbstractModelClass):
         self.layer_configuration = (n_layer, n_hidden)
         self._update_model_name()
         self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
+        self.kernel_regularizer = self._set_regularizer(regularizer, **kwargs)
+        self.dropout = self._set_dropout(dropout)
 
         # apply to model
         self.set_model()
@@ -116,6 +119,30 @@ class FCN(AbstractModelClass):
         except KeyError:
             raise AttributeError(f"Given optimizer {optimizer} is not supported in this model class.")
 
+    def _set_regularizer(self, regularizer, **kwargs):
+        if regularizer is None:
+            return regularizer
+        try:
+            reg_name = regularizer.lower()
+            reg = self._regularizer.get(reg_name)
+            reg_kwargs = {}
+            if reg_name in ["l1", "l2"]:
+                reg_kwargs = select_from_dict(kwargs, reg_name, remove_none=True)
+                if reg_name in reg_kwargs:
+                    reg_kwargs["l"] = reg_kwargs.pop(reg_name)
+            elif reg_name == "l1_l2":
+                reg_kwargs = select_from_dict(kwargs, ["l1", "l2"], remove_none=True)
+            return reg(**reg_kwargs)
+        except KeyError:
+            raise AttributeError(f"Given regularizer {regularizer} is not supported in this model class.")
+
+    @staticmethod
+    def _set_dropout(dropout):
+        if dropout is None:
+            return dropout
+        assert 0 <= dropout < 1
+        return dropout
+
     def _update_model_name(self):
         n_layer, n_hidden = self.layer_configuration
         n_input = str(reduce(lambda x, y: x * y, self._input_shape))
@@ -130,8 +157,11 @@ class FCN(AbstractModelClass):
         x_in = keras.layers.Flatten()(x_input)
         n_layer, n_hidden = self.layer_configuration
         for layer in range(n_layer):
-            x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer)(x_in)
+            x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer,
+                                      kernel_regularizer=self.kernel_regularizer)(x_in)
             x_in = self.activation()(x_in)
+            if self.dropout is not None:
+                x_in = keras.layers.Dropout(self.dropout)(x_in)
         x_in = keras.layers.Dense(self._output_shape)(x_in)
         out = self.activation_output()(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out])
diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py
index f2e2b341afa424ce351c0253f41c75e362b77eba..91f2278ae7668b623f8d2434ebac7e959dc9c805 100644
--- a/test/test_helpers/test_helpers.py
+++ b/test/test_helpers/test_helpers.py
@@ -175,7 +175,7 @@ class TestSelectFromDict:
 
     @pytest.fixture
     def dictionary(self):
-        return {"a": 1, "b": 23, "c": "last"}
+        return {"a": 1, "b": 23, "c": "last", "e": None}
 
     def test_select(self, dictionary):
         assert select_from_dict(dictionary, "c") == {"c": "last"}
@@ -186,6 +186,10 @@
         with pytest.raises(AssertionError):
             select_from_dict(["we"], "now")
 
+    def test_select_remove_none(self, dictionary):
+        assert select_from_dict(dictionary, ["a", "e"]) == {"a": 1, "e": None}
+        assert select_from_dict(dictionary, ["a", "e"], remove_none=True) == {"a": 1}
+
 
 class TestRemoveItems:
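Usage sketch (not part of the diff): with these changes the new keyword arguments can be passed straight to FCN, and the l1/l2 coefficients travel through **kwargs because they are now listed in _requirements. The input/output shapes below are illustrative only; how a concrete MLAir experiment builds them is assumed here.

    from mlair.model_modules.fully_connected_networks import FCN

    # Combined penalty: _set_regularizer picks l1/l2 out of **kwargs via
    # select_from_dict(kwargs, ["l1", "l2"], remove_none=True) and builds keras.regularizers.l1_l2.
    # Dropout is applied after each hidden activation and must satisfy 0 <= dropout < 1.
    model = FCN(input_shape=[(14, 1, 5)], output_shape=[(4,)],
                n_layer=2, n_hidden=32,
                regularizer="l1_l2", l1=0.01, l2=0.01,
                dropout=0.2)

    # Single penalty: the coefficient is renamed to keras' "l" argument,
    # i.e. keras.regularizers.l2(l=0.01).
    model_l2 = FCN(input_shape=[(14, 1, 5)], output_shape=[(4,)],
                   regularizer="l2", l2=0.01)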