diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py index 0dbd2d9b67a0748bf09eb4f59e1888aae1ea405d..13f982b80906d8d5d6beae7075b23f4c84d6edd1 100644 --- a/test/test_model_modules/test_model_class.py +++ b/test/test_model_modules/test_model_class.py @@ -2,6 +2,7 @@ import keras import pytest from src.model_modules.model_class import AbstractModelClass +from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel class TestAbstractModelClass: @@ -27,3 +28,35 @@ class TestAbstractModelClass: assert hasattr(amc, "compile") is True assert hasattr(amc.model, "compile") is True assert amc.compile == amc.model.compile + + +class TestMyPaperModel: + + @pytest.fixture + def mpm(self): + return MyPaperModel(window_history_size=6, window_lead_time=4, channels=9) + + def test_init(self, mpm): + # check if loss number of loss functions fit to model outputs + # same loss fkts. for all tails or different fkts. per tail + if isinstance(mpm.model.output_shape, list): + assert (callable(mpm.loss) or (len(mpm.loss) == 1)) or (len(mpm.loss) == len(mpm.model.output_shape)) + elif isinstance(mpm.model.output_shape, tuple): + assert callable(mpm.loss) or (len(mpm.loss) == 1) + + def test_set_model(self, mpm): + assert isinstance(mpm.model, keras.Model) + assert mpm.model.layers[0].output_shape == (None, 7, 1, 9) + # check output dimensions + if isinstance(mpm.model.output_shape, tuple): + assert mpm.model.output_shape == (None, 4) + elif isinstance(mpm.model.output_shape, list): + for tail_shape in mpm.model.output_shape: + assert tail_shape == (None, 4) + else: + raise TypeError(f"Type of model.output_shape as to be a tuple (one tail)" + f" or a list of tuples (multiple tails). Received: {type(mpm.model.output_shape)}") + + def test_set_loss(self, mpm): + assert callable(mpm.loss) or (len(mpm.loss) > 0) +