import keras
import pytest

from src.model_modules.model_class import AbstractModelClass
from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel


class Paddings:
    """Dummy object passed as ``Padding2D`` in test_set_custom_objects.

    It only provides the ``allowed_paddings`` mapping. Its entries are expected to end up
    in ``custom_objects`` alongside the object itself.
    """

    # Extra custom objects that set_custom_objects is expected to merge in.
    allowed_paddings = dict(pad1=34, another_pad=True)


class AbstractModelSubClass(AbstractModelClass):
    """Minimal subclass used to check that subclass attributes are reported by get_settings()."""

    def __init__(self) -> None:
        # Initialise the base class first so its own attributes (model, loss, model_name, ...) exist.
        super().__init__()
        # Extra instance attribute that test_get_settings expects get_settings() to include.
        self.test_attr = "testAttr"


class TestAbstractModelClass:
    """Tests for AbstractModelClass: model/loss properties, compile options, settings and custom objects."""

    @pytest.fixture
    def amc(self):
        """Provide a fresh AbstractModelClass instance."""
        instance = AbstractModelClass()
        return instance

    @pytest.fixture
    def amsc(self):
        """Provide a fresh AbstractModelSubClass instance (carries the extra ``test_attr``)."""
        instance = AbstractModelSubClass()
        return instance

    def test_init(self, amc):
        """A new instance has no model or loss, is named after its class and has no custom objects."""
        assert all(attribute is None for attribute in (amc.model, amc.loss))
        assert amc.model_name == "AbstractModelClass"
        assert amc.custom_objects == {}

    def test_model_property(self, amc):
        """The model setter accepts a keras.Model and the getter returns it."""
        amc.model = keras.Model()
        assert isinstance(amc.model, keras.Model)

    def test_loss_property(self, amc):
        """The loss setter stores the given loss function unchanged."""
        mae = keras.losses.mean_absolute_error
        amc.loss = mae
        assert amc.loss == mae

    def test_compile_options_property(self, amc):
        """Options not given explicitly are filled with None when compile options are set."""
        amc.compile_options = {"metrics": ["mse", "mae"]}
        expected = {
            "loss_weights": None,
            "metrics": ["mse", "mae"],
            "sample_weight_mode": None,
            "target_tensors": None,
            "weighted_metrics": None,
        }
        assert amc.compile_options == expected

    def test_compile_options_property_type_error(self, amc):
        """Assigning a non-dict to compile_options raises a TypeError naming the received type."""
        with pytest.raises(TypeError) as excinfo:
            amc.compile_options = 'hello world'
        message = str(excinfo.value)
        assert "`value' has to be a dictionary. But it is <class 'str'>" in message

    def test_getattr(self, amc):
        """Attributes missing on the wrapper are expected to resolve to the wrapped keras model."""
        amc.model = keras.Model()
        assert hasattr(amc, "compile")
        assert hasattr(amc.model, "compile")
        assert amc.compile == amc.model.compile

    def test_get_settings(self, amc, amsc):
        """get_settings reports the model name plus any attributes added by a subclass."""
        expected_base = {"model_name": "AbstractModelClass"}
        expected_sub = {"test_attr": "testAttr", "model_name": "AbstractModelSubClass"}
        assert amc.get_settings() == expected_base
        assert amsc.get_settings() == expected_sub

    def test_custom_objects(self, amc):
        """The custom_objects setter stores the given mapping."""
        amc.custom_objects = {"Test": 123}
        assert amc.custom_objects == {"Test": 123}

    def test_set_custom_objects(self, amc):
        """Each set_custom_objects call replaces the previous custom objects."""
        # Each step: keyword arguments passed in, and the exact custom_objects expected afterwards
        # (the second step shows that earlier entries such as "minor_param" are dropped).
        steps = (
            ({"Test": 22, "minor_param": "minor"}, {"Test": 22, "minor_param": "minor"}),
            ({"Test": 2, "minor_param1": "minor1"}, {"Test": 2, "minor_param1": "minor1"}),
        )
        for call_kwargs, expected in steps:
            amc.set_custom_objects(**call_kwargs)
            assert amc.custom_objects == expected
        # The entries of Paddings.allowed_paddings are expected to be merged in as well.
        paddings = Paddings()
        amc.set_custom_objects(Test=1, Padding2D=paddings)
        assert amc.custom_objects == {"Test": 1, "Padding2D": paddings, "pad1": 34, "another_pad": True}


class TestMyPaperModel:

    @pytest.fixture
    def mpm(self):
        """Provide a MyPaperModel built with window_history_size=6, window_lead_time=4 and channels=9."""
        model_kwargs = dict(window_history_size=6, window_lead_time=4, channels=9)
        return MyPaperModel(**model_kwargs)

    def test_init(self, mpm):
        """Check that the configured loss fits the model's output structure.

        A single callable loss, or a one-element loss collection, is accepted for any model.
        A multi-tail model (``output_shape`` is a list) may instead provide one loss per tail.
        Any other ``output_shape`` type makes the test fail.
        """
        output_shape = mpm.model.output_shape
        if isinstance(output_shape, list):
            # Multiple tails: the same loss for all tails, or a separate loss per tail.
            assert callable(mpm.loss) or len(mpm.loss) == 1 or len(mpm.loss) == len(output_shape)
        elif isinstance(output_shape, tuple):
            # Single tail: exactly one loss function.
            assert callable(mpm.loss) or len(mpm.loss) == 1
        else:
            # Fail loudly instead of passing without checking anything (mirrors test_set_model).
            raise TypeError("Type of model.output_shape has to be a tuple (one tail)"
                            f" or a list of tuples (multiple tails). Received: {type(output_shape)}")

    def test_set_model(self, mpm):
        """Check that the built model is a keras.Model with the expected input and output dimensions.

        For the ``mpm`` fixture (window_history_size=6, window_lead_time=4, channels=9):
        - the first layer is expected to output (None, 7, 1, 9);
        - every output tail is expected to be (None, 4).
        """
        assert isinstance(mpm.model, keras.Model)
        assert mpm.model.layers[0].output_shape == (None, 7, 1, 9)
        # Check output dimensions: a tuple means a single tail, a list holds one shape per tail.
        output_shape = mpm.model.output_shape
        if isinstance(output_shape, tuple):
            assert output_shape == (None, 4)
        elif isinstance(output_shape, list):
            for tail_shape in output_shape:
                assert tail_shape == (None, 4)
        else:
            raise TypeError("Type of model.output_shape has to be a tuple (one tail)"
                            f" or a list of tuples (multiple tails). Received: {type(output_shape)}")

    def test_set_loss(self, mpm):
        """The model must define a loss: a single callable or a non-empty collection of losses."""
        loss = mpm.loss
        assert callable(loss) or len(loss) > 0