import keras
import numpy as np
import pytest
import mock
import os

from src.helpers import l_p_loss
from src.model_modules.keras_extensions import *


class TestHistoryAdvanced:
    """Tests for the ``HistoryAdvanced`` callback."""

    def test_init(self):
        """A freshly built callback has no model/validation data and empty records."""
        history_cb = HistoryAdvanced()
        assert history_cb.validation_data is None
        assert history_cb.model is None

        epochs = history_cb.epoch
        assert isinstance(epochs, list)
        assert len(epochs) == 0

        records = history_cb.history
        assert isinstance(records, dict)
        assert len(records.keys()) == 0

    def test_on_train_begin(self):
        """``on_train_begin`` leaves previously recorded epochs and history untouched."""
        history_cb = HistoryAdvanced()
        history_cb.epoch = [1, 2, 3]
        history_cb.history = {"mse": [10, 7, 4]}

        history_cb.on_train_begin()

        expected_epochs = [1, 2, 3]
        expected_history = {"mse": [10, 7, 4]}
        assert history_cb.epoch == expected_epochs
        assert history_cb.history == expected_history


class TestLearningRateDecay:
    """Tests for the ``LearningRateDecay`` step-decay callback."""

    def test_init(self):
        """Default construction yields an empty lr log and the documented defaults."""
        lr_decay = LearningRateDecay()
        assert lr_decay.lr == {'lr': []}
        assert lr_decay.base_lr == 0.01
        assert lr_decay.drop == 0.96
        assert lr_decay.epochs_drop == 8

    def test_check_param(self):
        """``check_param`` returns in-range values and raises ValueError otherwise.

        The default range is (0, 1]; ``lower=None`` / ``upper=None`` remove the
        respective bound.
        """
        # Bypass __init__: only the validation helper is under test here.
        lr_decay = object.__new__(LearningRateDecay)
        assert lr_decay.check_param(1, "tester") == 1
        assert lr_decay.check_param(0.5, "tester") == 0.5
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(0, "tester")
        assert "tester is out of allowed range (0, 1]: tester=0" in e.value.args[0]
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(1.5, "tester")
        assert "tester is out of allowed range (0, 1]: tester=1.5" in e.value.args[0]
        assert lr_decay.check_param(1.5, "tester", upper=None) == 1.5
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(0, "tester", upper=None)
        assert "tester is out of allowed range (0, inf): tester=0" in e.value.args[0]
        assert lr_decay.check_param(0.5, "tester", lower=None) == 0.5
        with pytest.raises(ValueError) as e:
            lr_decay.check_param(0.5, "tester", lower=None, upper=0.2)
        assert "tester is out of allowed range (-inf, 0.2]: tester=0.5" in e.value.args[0]
        # Compare against the exact value: a bare truthiness check would accept
        # any non-zero return and not prove the value is passed through unchanged.
        assert lr_decay.check_param(10, "tester", upper=None, lower=None) == 10

    def test_on_epoch_begin(self):
        """The recorded lr is multiplied by ``drop`` once every ``epochs_drop`` epochs."""
        lr_decay = LearningRateDecay(base_lr=0.02, drop=0.95, epochs_drop=2)
        model = keras.Sequential()
        model.add(keras.layers.Dense(1, input_dim=1))
        model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2))
        model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=5, callbacks=[lr_decay])
        # Tolerant float comparison: exact equality on products like 0.02 * 0.95
        # is fragile if the values pass through the optimizer/backend precision.
        expected = [0.02, 0.02, 0.02 * 0.95, 0.02 * 0.95, 0.02 * 0.95 * 0.95]
        assert lr_decay.lr['lr'] == pytest.approx(expected)


class TestModelCheckpointAdvanced:
    """Tests for ``ModelCheckpointAdvanced`` and the extra callbacks it persists."""

    @pytest.fixture()
    def callbacks(self, tmp_path):
        """Two callbacks whose save paths point into a per-test temporary directory.

        Writing into ``tmp_path`` rather than the test module's own directory keeps
        runs isolated: files left over from an earlier failed run can no longer make
        the ``not in os.listdir(...)`` assertions fail, and the source tree stays clean
        even when an assertion fails before manual cleanup.
        """
        directory = str(tmp_path)
        return [{"callback": LearningRateDecay(), "path": os.path.join(directory, "callback_lr")},
                {"callback": HistoryAdvanced(), "path": os.path.join(directory, "callback_hist")}]

    @pytest.fixture()
    def ckpt(self, callbacks):
        """Checkpoint monitoring ``val_loss`` that only saves on improvement."""
        ckpt_name = "ckpt.test"
        return ModelCheckpointAdvanced(filepath=ckpt_name, monitor='val_loss', save_best_only=True,
                                       callbacks=callbacks, verbose=1)

    def test_init(self, ckpt, callbacks):
        """Constructor stores the callbacks and monitoring settings; best starts at +inf."""
        assert ckpt.callbacks == callbacks
        assert ckpt.monitor == "val_loss"
        assert ckpt.save_best_only is True
        assert ckpt.best == np.inf

    def test_update_best(self, ckpt):
        """``update_best`` takes the last recorded ``val_loss`` from a history object."""
        hist = HistoryAdvanced()
        hist.history["val_loss"] = [10, 6]
        ckpt.update_best(hist)
        assert ckpt.best == 6

    def test_update_callbacks(self, ckpt, callbacks):
        """``update_callbacks`` with a single dict replaces the list with ``[that dict]``."""
        ckpt.update_callbacks(callbacks[0])
        assert ckpt.callbacks == [callbacks[0]]

    def test_on_epoch_end(self, ckpt, tmp_path):
        """Callback files are written only on improvement unless ``save_best_only`` is off."""
        # Same tmp_path instance the ``callbacks`` fixture used for this test.
        path = str(tmp_path)
        saved = {"callback_hist", "callback_lr"}
        ckpt.set_model(mock.MagicMock())
        ckpt.best = 6
        # Equal to the current best: not an improvement, nothing is written.
        ckpt.on_epoch_end(0, {"val_loss": 6})
        assert "callback_hist" not in os.listdir(path)
        # Worse than the current best: nothing is written.
        ckpt.on_epoch_end(9, {"val_loss": 10})
        assert "callback_hist" not in os.listdir(path)
        # Improvement: both callback files are written.
        ckpt.on_epoch_end(10, {"val_loss": 4})
        assert saved <= set(os.listdir(path))
        # Remove them so the next step proves they are written again.
        for name in saved:
            os.remove(os.path.join(path, name))
        ckpt.save_best_only = False
        ckpt.on_epoch_end(10, {"val_loss": 3})
        assert saved <= set(os.listdir(path))