From e5616b9bc26f5b881db25870bdb510081b7e4978 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 12 Nov 2019 09:31:41 +0100
Subject: [PATCH] simple test for l_p_loss

---
 src/helpers.py       | 10 ++++------
 test/test_helpers.py | 17 +++++++++++++++++
 2 files changed, 21 insertions(+), 6 deletions(-)
 create mode 100644 test/test_helpers.py

diff --git a/src/helpers.py b/src/helpers.py
index b4cbbcca..ec6aeb56 100644
--- a/src/helpers.py
+++ b/src/helpers.py
@@ -14,12 +14,10 @@ def to_list(arg):
     return arg
 
 
-class Loss:
-
-    def l_p_loss(self, power):
-        def loss(y_true, y_pred):
-            return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1)
-        return loss
+def l_p_loss(power):
+    def loss(y_true, y_pred):
+        return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1)
+    return loss
 
 
 class lrDecay(keras.callbacks.History):
diff --git a/test/test_helpers.py b/test/test_helpers.py
new file mode 100644
index 00000000..163d2682
--- /dev/null
+++ b/test/test_helpers.py
@@ -0,0 +1,17 @@
+import pytest
+from src.helpers import l_p_loss
+import logging
+import os
+import keras
+import keras.backend as K
+import numpy as np
+
+
+class TestLoss:
+
+    def test_l_p_loss(self):
+        model = keras.Sequential()
+        model.add(keras.layers.Lambda(lambda x: x, input_shape=(None, )))
+        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
+        hist = model.fit(np.array([1, 0]), np.array([1, 1]), epochs=1)
+        assert hist.history['loss'][0] == pytest.approx(0.5)
-- 
GitLab