diff --git a/src/helpers.py b/src/helpers.py index b4cbbccae75b7f48dd66cb316f38ad6e8dbb4e2f..ec6aeb5621a473f9f531cdaa3368dea9c4272ee6 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -14,12 +14,10 @@ def to_list(arg): return arg -class Loss: - - def l_p_loss(self, power): - def loss(y_true, y_pred): - return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1) - return loss +def l_p_loss(power): + def loss(y_true, y_pred): + return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1) + return loss class lrDecay(keras.callbacks.History): diff --git a/test/test_helpers.py b/test/test_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..163d2682d9b4b856afeb2f425484046ed3cb657f --- /dev/null +++ b/test/test_helpers.py @@ -0,0 +1,17 @@ +import pytest +from src.helpers import l_p_loss +import logging +import os +import keras +import keras.backend as K +import numpy as np + + +class TestLoss: + + def test_l_p_loss(self): + model = keras.Sequential() + model.add(keras.layers.Lambda(lambda x: x, input_shape=(None, ))) + model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2)) + hist = model.fit(np.array([1, 0]), np.array([1, 1]), epochs=1) + assert hist.history['loss'][0] == 0.5