diff --git a/mlair/model_modules/fully_connected_networks.py b/mlair/model_modules/fully_connected_networks.py
index dbcd3a9f41ca1b9a7435be95b93eb40c2b37c5a0..45b8eb63ca7afc6e1de5cd3767bd3a11280d3b43 100644
--- a/mlair/model_modules/fully_connected_networks.py
+++ b/mlair/model_modules/fully_connected_networks.py
@@ -5,6 +5,7 @@ from functools import reduce, partial
 
 from mlair.model_modules import AbstractModelClass
 from mlair.helpers import select_from_dict
+from mlair.model_modules.loss import var_loss, custom_loss
 
 import keras
 
@@ -64,7 +65,9 @@ class FCN(AbstractModelClass):
 
     _activation = {"relu": keras.layers.ReLU, "tanh": partial(keras.layers.Activation, "tanh"),
                    "sigmoid": partial(keras.layers.Activation, "sigmoid"),
-                   "linear": partial(keras.layers.Activation, "linear")}
+                   "linear": partial(keras.layers.Activation, "linear"),
+                   "selu": partial(keras.layers.Activation, "selu")}
+    _initializer = {"selu": keras.initializers.lecun_normal()}
     _optimizer = {"adam": keras.optimizers.adam, "sgd": keras.optimizers.SGD}
     _requirements = ["lr", "beta_1", "beta_2", "epsilon", "decay", "amsgrad", "momentum", "nesterov"]
@@ -87,6 +90,7 @@ class FCN(AbstractModelClass):
         self.optimizer = self._set_optimizer(optimizer, **kwargs)
         self.layer_configuration = (n_layer, n_hidden)
         self._update_model_name()
+        self.kernel_initializer = self._initializer.get(activation, "glorot_uniform")
 
         # apply to model
         self.set_model()
@@ -126,11 +130,12 @@ class FCN(AbstractModelClass):
         x_in = keras.layers.Flatten()(x_input)
         n_layer, n_hidden = self.layer_configuration
         for layer in range(n_layer):
-            x_in = keras.layers.Dense(n_hidden)(x_in)
+            x_in = keras.layers.Dense(n_hidden, kernel_initializer=self.kernel_initializer)(x_in)
             x_in = self.activation()(x_in)
         x_in = keras.layers.Dense(self._output_shape)(x_in)
         out = self.activation_output()(x_in)
         self.model = keras.Model(inputs=x_input, outputs=[out])
 
     def set_compile_options(self):
-        self.compile_options = {"loss": [keras.losses.mean_squared_error], "metrics": ["mse", "mae"]}
+        self.compile_options = {"loss": [custom_loss([keras.losses.mean_squared_error, var_loss])],
+                                "metrics": ["mse", "mae", var_loss]}
diff --git a/mlair/model_modules/loss.py b/mlair/model_modules/loss.py
index bcb85282d0fa15f18ebd65a89e4020c2a0170224..ba871e983ecfa1e91676d53b834ebd622c00fe49 100644
--- a/mlair/model_modules/loss.py
+++ b/mlair/model_modules/loss.py
@@ -20,3 +20,21 @@ def l_p_loss(power: int) -> Callable:
         return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1)
 
     return loss
+
+
+def var_loss(y_true, y_pred) -> Callable:
+    return K.mean(K.square(K.var(y_true) - K.var(y_pred)))
+
+
+def custom_loss(loss_list, loss_weights=None) -> Callable:
+    n = len(loss_list)
+    if loss_weights is None:
+        loss_weights = [1. / n for _ in range(n)]
+    else:
+        assert len(loss_weights) == n
+        loss_weights = [w / sum(loss_weights) for w in loss_weights]
+
+    def loss(y_true, y_pred):
+        return sum([loss_weights[i] * loss_list[i](y_true, y_pred) for i in range(n)])
+
+    return loss
diff --git a/test/test_model_modules/test_loss.py b/test/test_model_modules/test_loss.py
index e54e0b00de4a71d241f30e0b6b0c1a2e8fa1a19c..c993830c5290c9beeec392dfd806354ca02eb490 100644
--- a/test/test_model_modules/test_loss.py
+++ b/test/test_model_modules/test_loss.py
@@ -1,10 +1,12 @@
 import keras
 import numpy as np
 
-from mlair.model_modules.loss import l_p_loss
+from mlair.model_modules.loss import l_p_loss, var_loss, custom_loss
+import pytest
 
 
-class TestLoss:
+
+class TestLPLoss:
 
     def test_l_p_loss(self):
         model = keras.Sequential()
@@ -14,4 +16,42 @@ class TestLoss:
         assert hist.history['loss'][0] == 1.25
         model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(3))
         hist = model.fit(np.array([1, 0, -2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
-        assert hist.history['loss'][0] == 2.25
\ No newline at end of file
+        assert hist.history['loss'][0] == 2.25
+
+
+class TestVarLoss:
+
+    def test_var_loss(self):
+        model = keras.Sequential()
+        model.add(keras.layers.Lambda(lambda x: x, input_shape=(None,)))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=var_loss)
+        hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
+        assert hist.history['loss'][0] == 0.140625
+
+
+class TestCustomLoss:
+
+    def test_custom_loss_no_weights(self):
+        cust_loss = custom_loss([l_p_loss(2), var_loss])
+        model = keras.Sequential()
+        model.add(keras.layers.Lambda(lambda x: x, input_shape=(None,)))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=cust_loss)
+        hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
+        assert hist.history['loss'][0] == (0.5 * 0.140625 + 0.5 * 1.25)
+
+    @pytest.mark.parametrize("weights", [[0.3, 0.7], [0.5, 0.5], [1, 1], [4, 1]])
+    def test_custom_loss_with_weights(self, weights):
+        cust_loss = custom_loss([l_p_loss(2), var_loss], weights)
+        model = keras.Sequential()
+        model.add(keras.layers.Lambda(lambda x: x, input_shape=(None,)))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=cust_loss)
+        hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
+        weights_adjusted = list(map(lambda x: x / sum(weights), weights))
+        expected = (weights_adjusted[0] * 1.25 + weights_adjusted[1] * 0.140625)
+        assert np.testing.assert_almost_equal(hist.history['loss'][0], expected, decimal=6) is None
+
+    def test_custom_loss_invalid_weights(self):
+        with pytest.raises(AssertionError):
+            custom_loss([l_p_loss(2), var_loss], [0.3])
+        with pytest.raises(AssertionError):
+            custom_loss([l_p_loss(2), var_loss], [0.4, 3, 1])
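For reference, a minimal numpy sketch (not part of the diff) of what the combined loss computes, using the same fixture arrays as the tests above; the weight normalisation mirrors `custom_loss` with its default equal weighting:

```python
import numpy as np

# Same fixtures as in the tests above.
y_true = np.array([1, 0, 2, 0.5])
y_pred = np.array([1, 1, 0, 0.5])

# l_p_loss(2): mean of |y_pred - y_true|**2 -> mean([0, 1, 4, 0]) = 1.25
l2 = np.mean(np.abs(y_pred - y_true) ** 2)

# var_loss: squared difference of the (population) variances
# var(y_true) = 0.546875, var(y_pred) = 0.171875 -> 0.375**2 = 0.140625
vl = (np.var(y_true) - np.var(y_pred)) ** 2

# custom_loss([mse, var_loss]) defaults to equal weights summing to 1:
# 0.5 * 1.25 + 0.5 * 0.140625 = 0.6953125
combined = 0.5 * l2 + 0.5 * vl
print(l2, vl, combined)  # 1.25 0.140625 0.6953125
```

These are the values the new tests assert (1.25, 0.140625, and their weighted sums), so the sketch can double as a sanity check when adjusting `loss_weights`.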