From 5e25dbc64ac03e49e4951d9a2971023270f36cd9 Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Thu, 12 Mar 2020 09:32:36 +0100 Subject: [PATCH] update tests for paperModel --- test/test_model_modules/test_model_class.py | 33 +++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py index 0dbd2d9b..13f982b8 100644 --- a/test/test_model_modules/test_model_class.py +++ b/test/test_model_modules/test_model_class.py @@ -2,6 +2,7 @@ import keras import pytest from src.model_modules.model_class import AbstractModelClass +from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel class TestAbstractModelClass: @@ -27,3 +28,35 @@ class TestAbstractModelClass: assert hasattr(amc, "compile") is True assert hasattr(amc.model, "compile") is True assert amc.compile == amc.model.compile + + +class TestMyPaperModel: + + @pytest.fixture + def mpm(self): + return MyPaperModel(window_history_size=6, window_lead_time=4, channels=9) + + def test_init(self, mpm): + # check if loss number of loss functions fit to model outputs + # same loss fkts. for all tails or different fkts. per tail + if isinstance(mpm.model.output_shape, list): + assert (callable(mpm.loss) or (len(mpm.loss) == 1)) or (len(mpm.loss) == len(mpm.model.output_shape)) + elif isinstance(mpm.model.output_shape, tuple): + assert callable(mpm.loss) or (len(mpm.loss) == 1) + + def test_set_model(self, mpm): + assert isinstance(mpm.model, keras.Model) + assert mpm.model.layers[0].output_shape == (None, 7, 1, 9) + # check output dimensions + if isinstance(mpm.model.output_shape, tuple): + assert mpm.model.output_shape == (None, 4) + elif isinstance(mpm.model.output_shape, list): + for tail_shape in mpm.model.output_shape: + assert tail_shape == (None, 4) + else: + raise TypeError(f"Type of model.output_shape as to be a tuple (one tail)" + f" or a list of tuples (multiple tails). Received: {type(mpm.model.output_shape)}") + + def test_set_loss(self, mpm): + assert callable(mpm.loss) or (len(mpm.loss) > 0) + -- GitLab