import keras
import pytest

from src.model_modules.model_class import AbstractModelClass


class TestAbstractModelClass:
    """Unit tests for the ``model`` / ``loss`` attributes of AbstractModelClass."""

    @pytest.fixture
    def amc(self):
        """Provide a freshly constructed AbstractModelClass for each test."""
        instance = AbstractModelClass()
        return instance

    def test_init(self, amc):
        """A new instance has neither a model nor a loss set."""
        assert amc.model is None
        assert amc.loss is None

    def test_model_property(self, amc):
        """A keras model assigned to ``model`` is readable back as a keras model."""
        network = keras.Model()
        amc.model = network
        assert isinstance(amc.model, keras.Model)

    def test_loss_property(self, amc):
        """A loss function assigned to ``loss`` is returned unchanged."""
        loss_fn = keras.losses.mean_absolute_error
        amc.loss = loss_fn
        assert amc.loss == loss_fn

    def test_getattr(self, amc):
        """``compile`` is reachable on the wrapper and equals the model's own.

        NOTE(review): presumably the wrapper forwards unknown attributes to
        ``model`` (the test name hints at ``__getattr__``) -- confirm against
        AbstractModelClass.
        """
        amc.model = keras.Model()
        # Both the wrapper and the wrapped model must expose ``compile``.
        for holder in (amc, amc.model):
            assert hasattr(holder, "compile")
        assert amc.compile == amc.model.compile