Skip to content
Snippets Groups Projects
Commit 5e25dbc6 authored by Felix Kleinert's avatar Felix Kleinert
Browse files

update tests for paperModel

parent 7915b04e
Branches
Tags
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!69Felix issue077 feat model like in paper
Pipeline #31703 passed
......@@ -2,6 +2,7 @@ import keras
import pytest
from src.model_modules.model_class import AbstractModelClass
from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel
class TestAbstractModelClass:
......@@ -27,3 +28,35 @@ class TestAbstractModelClass:
assert hasattr(amc, "compile") is True
assert hasattr(amc.model, "compile") is True
assert amc.compile == amc.model.compile
class TestMyPaperModel:
@pytest.fixture
def mpm(self):
return MyPaperModel(window_history_size=6, window_lead_time=4, channels=9)
def test_init(self, mpm):
# check if loss number of loss functions fit to model outputs
# same loss fkts. for all tails or different fkts. per tail
if isinstance(mpm.model.output_shape, list):
assert (callable(mpm.loss) or (len(mpm.loss) == 1)) or (len(mpm.loss) == len(mpm.model.output_shape))
elif isinstance(mpm.model.output_shape, tuple):
assert callable(mpm.loss) or (len(mpm.loss) == 1)
def test_set_model(self, mpm):
assert isinstance(mpm.model, keras.Model)
assert mpm.model.layers[0].output_shape == (None, 7, 1, 9)
# check output dimensions
if isinstance(mpm.model.output_shape, tuple):
assert mpm.model.output_shape == (None, 4)
elif isinstance(mpm.model.output_shape, list):
for tail_shape in mpm.model.output_shape:
assert tail_shape == (None, 4)
else:
raise TypeError(f"Type of model.output_shape as to be a tuple (one tail)"
f" or a list of tuples (multiple tails). Received: {type(mpm.model.output_shape)}")
def test_set_loss(self, mpm):
assert callable(mpm.loss) or (len(mpm.loss) > 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment