import tensorflow.keras as keras
import pytest

from mlair import AbstractModelClass


class Paddings:
    """Dummy object passed as ``Padding2D`` in ``test_set_custom_objects``.

    Its only content is the ``allowed_paddings`` mapping. That test expects these
    entries to appear among the model's custom objects.
    """

    allowed_paddings = dict(pad1=34, another_pad=True)


class AbstractModelSubClass(AbstractModelClass):
    """Minimal concrete subclass of AbstractModelClass used by the tests below.

    It fixes the input/output shapes and adds one extra public attribute
    (``test_attr``). ``test_get_settings`` expects that attribute to show up in the
    subclass settings.
    """

    def __init__(self):
        # plain locals only: any new instance attribute would change get_settings()
        shape_in, shape_out = (12, 1, 2), 3
        super(AbstractModelSubClass, self).__init__(input_shape=shape_in, output_shape=shape_out)
        self.test_attr = "testAttr"


class TestAbstractModelClass:
    """Tests for model/attribute handling, compile options, settings and custom objects of AbstractModelClass."""

    @pytest.fixture
    def amc(self):
        """Plain AbstractModelClass; ``output_shape`` is passed as the 1-tuple ``(3,)``."""
        return AbstractModelClass(input_shape=(14, 1, 2), output_shape=(3,))

    @pytest.fixture
    def amsc(self):
        """Instance of the minimal ``AbstractModelSubClass`` defined in this module."""
        return AbstractModelSubClass()

    @staticmethod
    def _duplicate_optimizer_message(attr_optimizer, dict_optimizer):
        """Return the ValueError text expected when attribute and dict optimizers conflict.

        The text is built from the optimizers' classes instead of hard-coding
        TensorFlow's internal module paths. The assertions therefore keep working
        if the package layout changes (e.g. ``tensorflow.python.keras.optimizer_v2``
        moving elsewhere).
        """
        return ("Got different values or arguments for same argument: "
                f"self.optimizer={type(attr_optimizer)} and 'optimizer': {type(dict_optimizer)}")

    def test_init(self, amc):
        assert amc.model is None
        # assert amc.loss is None
        assert amc.model_name == "AbstractModelClass"
        assert amc.custom_objects == {}
        assert amc._input_shape == (14, 1, 2)
        # the fixture passes output_shape=(3,), which is expected to be stored as the plain int 3
        assert amc._output_shape == 3

    def test_model_property(self, amc):
        amc.model = keras.Model()
        assert isinstance(amc.model, keras.Model)

    # def test_loss_property(self, amc):
    #     amc.loss = keras.losses.mean_absolute_error
    #     assert amc.loss == keras.losses.mean_absolute_error

    def test_compile_options_setter_all_empty(self, amc):
        amc.compile_options = None
        assert amc.compile_options == {'optimizer': None,
                                       'loss': None,
                                       'metrics': None,
                                       'loss_weights': None,
                                       'sample_weight_mode': None,
                                       'weighted_metrics': None,
                                       'target_tensors': None
                                       }

    # TODO: has to stay disabled until AbstractModelClass.__compare_keras_optimizers(new_v_attr, new_v_dic)
    #  works again.
    # def test_compile_options_setter_as_dict(self, amc):
    #     amc.compile_options = {"optimizer": keras.optimizers.SGD(),
    #                            "loss": keras.losses.mean_absolute_error,
    #                            "metrics": ["mse", "mae"]}
    #     assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
    #     assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
    #     assert amc.compile_options["metrics"] == ["mse", "mae"]
    #     assert amc.compile_options["loss_weights"] is None
    #     assert amc.compile_options["sample_weight_mode"] is None
    #     assert amc.compile_options["target_tensors"] is None
    #     assert amc.compile_options["weighted_metrics"] is None

    def test_compile_options_setter_as_attr(self, amc):
        amc.optimizer = keras.optimizers.SGD()
        amc.loss = keras.losses.mean_absolute_error
        amc.compile_options = None  # This line has to be called!
        # optimizer check
        assert isinstance(amc.optimizer, keras.optimizers.SGD)
        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
        # loss check
        assert amc.loss == keras.losses.mean_absolute_error
        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
        # check rest (all None as not set)
        assert amc.compile_options["metrics"] is None
        assert amc.compile_options["loss_weights"] is None
        assert amc.compile_options["sample_weight_mode"] is None
        assert amc.compile_options["target_tensors"] is None
        assert amc.compile_options["weighted_metrics"] is None

    def test_compile_options_setter_as_mix_attr_dict_no_duplicates(self, amc):
        amc.optimizer = keras.optimizers.SGD()
        amc.compile_options = {"loss": keras.losses.mean_absolute_error,
                               "loss_weights": [0.2, 0.8]}
        # check setting by attribute
        assert isinstance(amc.optimizer, keras.optimizers.SGD)
        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
        # check setting by dict
        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
        assert amc.compile_options["loss_weights"] == [0.2, 0.8]
        # check rest (all None as not set)
        assert amc.compile_options["metrics"] is None
        assert amc.compile_options["sample_weight_mode"] is None
        assert amc.compile_options["target_tensors"] is None
        assert amc.compile_options["weighted_metrics"] is None

    # TODO: has to stay disabled until AbstractModelClass.__compare_keras_optimizers(new_v_attr, new_v_dic)
    #  works again.
    # def test_compile_options_setter_as_mix_attr_dict_valid_duplicates_optimizer(self, amc):
    #     amc.optimizer = keras.optimizers.SGD()
    #     amc.metrics = ['mse']
    #     amc.compile_options = {"optimizer": keras.optimizers.SGD(),
    #                            "loss": keras.losses.mean_absolute_error}
    #     # check duplicate (attr and dic)
    #     assert isinstance(amc.optimizer, keras.optimizers.SGD)
    #     assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
    #     # check setting by dict
    #     assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
    #     # check setting by attr
    #     assert amc.metrics == ['mse']
    #     assert amc.compile_options["metrics"] == ['mse']
    #     # check rest (all None as not set)
    #     assert amc.compile_options["loss_weights"] is None
    #     assert amc.compile_options["sample_weight_mode"] is None
    #     assert amc.compile_options["target_tensors"] is None
    #     assert amc.compile_options["weighted_metrics"] is None

    def test_compile_options_setter_as_mix_attr_dict_valid_duplicates_none_optimizer(self, amc):
        amc.optimizer = keras.optimizers.SGD()
        amc.metrics = ['mse']
        amc.compile_options = {"metrics": ['mse'],
                               "loss": keras.losses.mean_absolute_error}
        # check duplicate (attr and dic)
        assert amc.metrics == ['mse']
        assert amc.compile_options["metrics"] == ['mse']
        # check setting by dict
        assert amc.compile_options["loss"] == keras.losses.mean_absolute_error
        # check setting by attr
        assert isinstance(amc.optimizer, keras.optimizers.SGD)
        assert isinstance(amc.compile_options["optimizer"], keras.optimizers.SGD)
        # check rest (all None as not set)
        assert amc.compile_options["loss_weights"] is None
        assert amc.compile_options["sample_weight_mode"] is None
        assert amc.compile_options["target_tensors"] is None
        assert amc.compile_options["weighted_metrics"] is None

    def test_compile_options_property_type_error(self, amc):
        with pytest.raises(TypeError) as einfo:
            amc.compile_options = 'hello world'
        assert "`compile_options' must be `dict' or `None', but is <class 'str'>." in str(einfo.value)

    def test_compile_options_setter_as_mix_attr_dict_invalid_duplicates_other_optimizer(self, amc):
        attr_optimizer = keras.optimizers.SGD()
        dict_optimizer = keras.optimizers.Adam()
        amc.optimizer = attr_optimizer
        with pytest.raises(ValueError) as einfo:
            amc.compile_options = {"optimizer": dict_optimizer}
        assert self._duplicate_optimizer_message(attr_optimizer, dict_optimizer) in str(einfo.value)

    def test_compile_options_setter_as_mix_attr_dict_invalid_duplicates_same_optimizer_other_args(self, amc):
        # `learning_rate` is the canonical TF2 keyword; the old `lr` alias is deprecated
        attr_optimizer = keras.optimizers.SGD(learning_rate=0.1)
        dict_optimizer = keras.optimizers.SGD(learning_rate=0.001)
        amc.optimizer = attr_optimizer
        with pytest.raises(ValueError) as einfo:
            amc.compile_options = {"optimizer": dict_optimizer}
        assert self._duplicate_optimizer_message(attr_optimizer, dict_optimizer) in str(einfo.value)

    def test_compile_options_setter_as_dict_invalid_keys(self, amc):
        with pytest.raises(ValueError) as einfo:
            amc.compile_options = {"optimizer": keras.optimizers.SGD(), "InvalidKeyword": [1, 2, 3]}
        assert "Got invalid key for compile_options. dict_keys(['optimizer', 'InvalidKeyword'])" in str(einfo.value)

    # TODO: re-enable once AbstractModelClass.__compare_keras_optimizers works again.
    # def test_compare_keras_optimizers_equal(self, amc):
    #     assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(), keras.optimizers.SGD()) is True
    #
    # def test_compare_keras_optimizers_no_optimizer(self, amc):
    #     assert amc._AbstractModelClass__compare_keras_optimizers('NoOptimizer', keras.optimizers.SGD()) is False
    #
    # def test_compare_keras_optimizers_other_parameters_run_sess(self, amc):
    #     assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(learning_rate=0.1),
    #                                                              keras.optimizers.SGD(learning_rate=0.01)) is False
    #
    # def test_compare_keras_optimizers_other_parameters_none_sess(self, amc):
    #     assert amc._AbstractModelClass__compare_keras_optimizers(keras.optimizers.SGD(decay=1),
    #                                                              keras.optimizers.SGD(decay=0.01)) is False

    def test_getattr(self, amc):
        amc.model = keras.Model()
        # attributes not found on the wrapper are expected to be looked up on the wrapped model
        assert hasattr(amc, "compile")
        assert hasattr(amc.model, "compile")
        assert amc.compile == amc.model.compile

    def test_get_settings(self, amc, amsc):
        assert amc.get_settings() == {"model_name": "AbstractModelClass", "_input_shape": (14, 1, 2),
                                      "_output_shape": 3}
        assert amsc.get_settings() == {"test_attr": "testAttr", "model_name": "AbstractModelSubClass",
                                       "_input_shape": (12, 1, 2), "_output_shape": 3}

    def test_custom_objects(self, amc):
        amc.custom_objects = {"Test": 123}
        assert amc.custom_objects == {"Test": 123}

    def test_set_custom_objects(self, amc):
        amc.set_custom_objects(Test=22, minor_param="minor")
        assert amc.custom_objects == {"Test": 22, "minor_param": "minor"}
        # a second call replaces the previous custom objects instead of extending them
        amc.set_custom_objects(Test=2, minor_param1="minor1")
        assert amc.custom_objects == {"Test": 2, "minor_param1": "minor1"}
        # the entries of Padding2D's `allowed_paddings` are expected to be added as well
        paddings = Paddings()
        amc.set_custom_objects(Test=1, Padding2D=paddings)
        assert amc.custom_objects == {"Test": 1, "Padding2D": paddings, "pad1": 34, "another_pad": True}