From e5616b9bc26f5b881db25870bdb510081b7e4978 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Tue, 12 Nov 2019 09:31:41 +0100 Subject: [PATCH] simple test for l_p_loss --- src/helpers.py | 10 ++++------ test/test_helpers.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) create mode 100644 test/test_helpers.py diff --git a/src/helpers.py b/src/helpers.py index b4cbbcca..ec6aeb56 100644 --- a/src/helpers.py +++ b/src/helpers.py @@ -14,12 +14,10 @@ def to_list(arg): return arg -class Loss: - - def l_p_loss(self, power): - def loss(y_true, y_pred): - return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1) - return loss +def l_p_loss(power): + def loss(y_true, y_pred): + return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1) + return loss class lrDecay(keras.callbacks.History): diff --git a/test/test_helpers.py b/test/test_helpers.py new file mode 100644 index 00000000..163d2682 --- /dev/null +++ b/test/test_helpers.py @@ -0,0 +1,17 @@ +import pytest +from src.helpers import l_p_loss +import logging +import os +import keras +import keras.backend as K +import numpy as np + + +class TestLoss: + + def test_l_p_loss(self): + model = keras.Sequential() + model.add(keras.layers.Lambda(lambda x: x, input_shape=(None, ))) + model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2)) + hist = model.fit(np.array([1, 0]), np.array([1, 1]), epochs=1) + assert hist.history['loss'][0] == 0.5 -- GitLab