import keras
import pytest

from mlair.model_modules.model_class import IntelliO3_ts_architecture


class TestIntelliO3_ts_architecture:
    """Tests for the ``IntelliO3_ts_architecture`` model class.

    The model is built with ``input_shape=[INPUT_SHAPE]`` and
    ``output_shape=[OUTPUT_SHAPE]``. The resulting keras model may expose one
    output tail (``model.output_shape`` is a tuple) or several (a list of tuples).
    """

    # Shapes handed to the model constructor. The fixture and the shape
    # assertions both read these, so the expected values cannot drift apart.
    INPUT_SHAPE = (7, 1, 9)
    OUTPUT_SHAPE = (4,)

    @staticmethod
    def _unexpected_output_shape(output_shape):
        """Build the error for a ``model.output_shape`` that is neither a tuple nor a list."""
        return TypeError(f"Type of model.output_shape has to be a tuple (one tail)"
                         f" or a list of tuples (multiple tails). Received: {type(output_shape)}")

    @pytest.fixture
    def mpm(self):
        """Build a fresh ``IntelliO3_ts_architecture`` instance for each test."""
        return IntelliO3_ts_architecture(input_shape=[self.INPUT_SHAPE], output_shape=[self.OUTPUT_SHAPE])

    def test_init(self, mpm):
        """Check that the number of loss functions fits the number of model outputs.

        A single callable, or a one-element collection, is always accepted.
        For a multi-tail model, one loss per output tail is accepted as well.
        """
        loss = mpm.compile_options["loss"]
        output_shape = mpm.model.output_shape
        if isinstance(output_shape, list):
            # same loss fkt. for all tails, or a different fkt. per tail
            assert callable(loss) or len(loss) in (1, len(output_shape))
        elif isinstance(output_shape, tuple):
            assert callable(loss) or len(loss) == 1
        else:
            # Fail loudly instead of passing without having checked anything.
            raise self._unexpected_output_shape(output_shape)

    def test_set_model(self, mpm):
        """Check the model type and its input/output dimensions."""
        assert isinstance(mpm.model, keras.Model)
        # keras reports the batch dimension as a leading ``None``.
        assert mpm.model.layers[0].output_shape == (None, *self.INPUT_SHAPE)
        expected_tail_shape = (None, *self.OUTPUT_SHAPE)
        output_shape = mpm.model.output_shape
        if isinstance(output_shape, tuple):
            assert output_shape == expected_tail_shape
        elif isinstance(output_shape, list):
            for tail_shape in output_shape:
                assert tail_shape == expected_tail_shape
        else:
            raise self._unexpected_output_shape(output_shape)

    def test_set_compile_options(self, mpm):
        """Check that a loss is configured: either a callable or a non-empty collection."""
        assert callable(mpm.compile_options["loss"]) or (len(mpm.compile_options["loss"]) > 0)