From 5e25dbc64ac03e49e4951d9a2971023270f36cd9 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Thu, 12 Mar 2020 09:32:36 +0100
Subject: [PATCH] update tests for paperModel

---
 test/test_model_modules/test_model_class.py | 33 +++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py
index 0dbd2d9b..13f982b8 100644
--- a/test/test_model_modules/test_model_class.py
+++ b/test/test_model_modules/test_model_class.py
@@ -2,6 +2,7 @@ import keras
 import pytest
 
 from src.model_modules.model_class import AbstractModelClass
+from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel
 
 
 class TestAbstractModelClass:
@@ -27,3 +28,35 @@ class TestAbstractModelClass:
         assert hasattr(amc, "compile") is True
         assert hasattr(amc.model, "compile") is True
         assert amc.compile == amc.model.compile
+
+
+class TestMyPaperModel:
+
+    @pytest.fixture
+    def mpm(self):
+        return MyPaperModel(window_history_size=6, window_lead_time=4, channels=9)
+
+    def test_init(self, mpm):
+        # check if loss number of loss functions fit to model outputs
+        #       same loss fkts. for all tails               or different fkts. per tail
+        if isinstance(mpm.model.output_shape, list):
+            assert (callable(mpm.loss) or (len(mpm.loss) == 1)) or (len(mpm.loss) == len(mpm.model.output_shape))
+        elif isinstance(mpm.model.output_shape, tuple):
+            assert callable(mpm.loss) or (len(mpm.loss) == 1)
+
+    def test_set_model(self, mpm):
+        assert isinstance(mpm.model, keras.Model)
+        assert mpm.model.layers[0].output_shape == (None, 7, 1, 9)
+        # check output dimensions
+        if isinstance(mpm.model.output_shape, tuple):
+            assert mpm.model.output_shape == (None, 4)
+        elif isinstance(mpm.model.output_shape, list):
+            for tail_shape in mpm.model.output_shape:
+                assert tail_shape == (None, 4)
+        else:
+            raise TypeError(f"Type of model.output_shape as to be a tuple (one tail)"
+                            f" or a list of tuples (multiple tails). Received: {type(mpm.model.output_shape)}")
+
+    def test_set_loss(self, mpm):
+        assert callable(mpm.loss) or (len(mpm.loss) > 0)
+
-- 
GitLab