Skip to content
Snippets Groups Projects
Commit 38849c46 authored by lukas leufen's avatar lukas leufen
Browse files

loss function is implemented with docs and test, #5

parent e5616b9b
Branches
Tags
2 merge requests!9new version v0.2.0,!7l_p_loss and lrdecay implementation
Checking pipeline status
......@@ -14,7 +14,13 @@ def to_list(arg):
return arg
def l_p_loss(power):
def l_p_loss(power: int):
"""
Calculate the L<p> loss for given power p. L1 (p=1) is equal to mean absolute error (MAE), L2 (p=2) is to mean
squared error (MSE), ...
:param power: set the power of the error calculus
:return: loss for given power
"""
def loss(y_true, y_pred):
return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1)
return loss
......
......@@ -13,5 +13,9 @@ class TestLoss:
model = keras.Sequential()
model.add(keras.layers.Lambda(lambda x: x, input_shape=(None, )))
model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
hist = model.fit(np.array([1, 0]), np.array([1, 1]), epochs=1)
assert hist.history['loss'][0] == 0.5
hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
assert hist.history['loss'][0] == 1.25
model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(3))
hist = model.fit(np.array([1, 0, -2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1)
assert hist.history['loss'][0] == 2.25
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment