import keras
import pytest
from keras.callbacks import ModelCheckpoint, History
from unittest import mock
import os
import json
import shutil
import logging
import glob

from src.inception_model import InceptionModelBase
from src.flatten import flatten_tail
from src.modules.training import Training
from src.modules.run_environment import RunEnvironment
from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator
from src.helpers import LearningRateDecay, PyTestRegex


def my_test_model(activation, window_history_size, channels, dropout_rate, add_minor_branch=False):
    """Build a small inception-based test model with a main output and, optionally, a minor branch."""
    inception_model = InceptionModelBase()
    conv_settings_dict1 = {
        'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
        'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
    X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
    if add_minor_branch:
        out = [flatten_tail(X_in, 'Minor_1', activation=activation)]
    else:
        out = []
    X_in = keras.layers.Dropout(dropout_rate)(X_in)
    out.append(flatten_tail(X_in, 'Main', activation=activation))
    return keras.Model(inputs=X_input, outputs=out)
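
# Usage sketch (not executed; assumes the keras layers imported above): the helper
# returns a functional model whose number of outputs depends on ``add_minor_branch``:
#
#     model = my_test_model(keras.layers.PReLU, window_history_size=7, channels=2,
#                           dropout_rate=0.1, add_minor_branch=True)
#     len(model.outputs)  # -> 2 ('Minor_1' and 'Main')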


class TestTraining:

    @pytest.fixture
    def init_without_run(self, path, model, checkpoint):
        obj = object.__new__(Training)  # allocate without calling Training.__init__
        super(Training, obj).__init__()  # run only the RunEnvironment base initialiser
        obj.model = model
        obj.train_set = None
        obj.val_set = None
        obj.test_set = None
        obj.batch_size = 256
        obj.epochs = 2
        obj.checkpoint = checkpoint
        obj.lr_sc = LearningRateDecay()
        obj.experiment_name = "TestExperiment"
        obj.data_store.put("generator", mock.MagicMock(return_value="mock_train_gen"), "general.train")
        obj.data_store.put("generator", mock.MagicMock(return_value="mock_val_gen"), "general.val")
        obj.data_store.put("generator", mock.MagicMock(return_value="mock_test_gen"), "general.test")
        os.makedirs(path)
        obj.data_store.put("experiment_path", path, "general")
        obj.data_store.put("experiment_name", "TestExperiment", "general")
        path_plot = os.path.join(path, "plots")
        os.makedirs(path_plot)
        obj.data_store.put("plot_path", path_plot, "general")
        yield obj
        if os.path.exists(path):
            shutil.rmtree(path)
        RunEnvironment().__del__()  # explicitly tear down the shared data store between tests
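
    # The fixture above uses the object.__new__ trick to build a Training instance
    # without running Training.__init__, so individual steps of the training stage
    # can be unit-tested in isolation. The same pattern, sketched with a
    # hypothetical class name:
    #
    #     obj = object.__new__(SomeRunStage)   # allocate, skip SomeRunStage.__init__
    #     super(SomeRunStage, obj).__init__()  # run only the base-class initialiser
    #     obj.attr = "stubbed for the test"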

    @pytest.fixture
    def learning_rate(self):
        return {"lr": [0.01, 0.0094]}

    @pytest.fixture
    def init_with_lr(self, init_without_run, learning_rate):
        init_without_run.lr_sc.lr = learning_rate
        return init_without_run

    @pytest.fixture
    def history(self):
        h = History()
        h.epoch = [0, 1]
        h.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
                     'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
                     'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
                     'loss': [0.6795708956961347, 0.45963566494176616],
                     'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
                     'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
        return h

    @pytest.fixture
    def path(self):
        return os.path.join(os.path.dirname(__file__), "TestExperiment")

    @pytest.fixture
    def generator(self, path):
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE',
                             ['DEBW107'], ['o3', 'temp'], 'datetime', 'variables',
                             'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
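
    # Note: this fixture reads the sample data for station DEBW107 from the local
    # `data` directory next to this file; `statistics_per_var` presumably selects
    # the aggregation statistic applied to each variable ('dma8eu' for o3,
    # 'maximum' for temp).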

    @pytest.fixture
    def model(self):
        return my_test_model(keras.layers.PReLU, 7, 2, 0.1, False)

    @pytest.fixture
    def checkpoint(self, path):
        return ModelCheckpoint(os.path.join(path, "model_checkpoint"), monitor='val_loss', save_best_only=True)
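
    # ModelCheckpoint with save_best_only=True only overwrites the checkpoint file
    # when the monitored quantity (val_loss) improves, so the file always contains
    # the best weights seen so far; load_best_model presumably reloads exactly this
    # file.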

    @pytest.fixture
    def ready_to_train(self, generator, init_without_run):
        init_without_run.train_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
        init_without_run.val_set = Distributor(generator, init_without_run.model, init_without_run.batch_size)
        init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
        return init_without_run

    @pytest.fixture
    def ready_to_run(self, generator, init_without_run):
        obj = init_without_run
        obj.data_store.put("generator", generator, "general.train")
        obj.data_store.put("generator", generator, "general.val")
        obj.data_store.put("generator", generator, "general.test")
        obj.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
        return obj

    @pytest.fixture
    def ready_to_init(self, generator, model, checkpoint, path):
        os.makedirs(path)
        obj = RunEnvironment()
        obj.data_store.put("generator", generator, "general.train")
        obj.data_store.put("generator", generator, "general.val")
        obj.data_store.put("generator", generator, "general.test")
        model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
        obj.data_store.put("model", model, "general.model")
        obj.data_store.put("batch_size", 256, "general.model")
        obj.data_store.put("epochs", 2, "general.model")
        obj.data_store.put("checkpoint", checkpoint, "general.model")
        obj.data_store.put("lr_decay", LearningRateDecay(), "general.model")
        obj.data_store.put("experiment_name", "TestExperiment", "general")
        obj.data_store.put("experiment_path", path, "general")
        path_plot = os.path.join(path, "plots")
        os.makedirs(path_plot)
        obj.data_store.put("plot_path", path_plot, "general")
        yield obj
        if os.path.exists(path):
            shutil.rmtree(path)
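
    # ready_to_init fills the data store with everything Training.__init__ is
    # expected to pull from it (model, batch_size, epochs, checkpoint, lr_decay,
    # the three generators and the experiment paths), so test_init below can
    # construct a Training instance directly.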

    def test_init(self, ready_to_init):
        assert isinstance(Training(), Training)  # just check that initialisation does not fail

    def test_run(self, ready_to_run):
        assert ready_to_run._run() is None  # just check that the run completes without failing

    def test_make_predict_function(self, init_without_run):
        assert not hasattr(init_without_run.model, "predict_function")
        init_without_run.make_predict_function()
        assert hasattr(init_without_run.model, "predict_function")

    def test_set_gen(self, init_without_run):
        assert init_without_run.train_set is None
        init_without_run._set_gen("train")
        assert isinstance(init_without_run.train_set, Distributor)
        assert init_without_run.train_set.generator.return_value == "mock_train_gen"

    def test_set_generators(self, init_without_run):
        sets = ["train", "val", "test"]
        assert all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
        init_without_run.set_generators()
        assert all([getattr(init_without_run, f"{obj}_set") is not None for obj in sets])
        assert all([getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets])

    def test_train(self, ready_to_train, path):
        assert not hasattr(ready_to_train.model, "history")
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
        ready_to_train.train()
        assert list(ready_to_train.model.history.history.keys()) == ["val_loss", "loss"]
        assert ready_to_train.model.history.epoch == [0, 1]
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2
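
    # The two PDFs checked above are presumably the loss-history and learning-rate
    # monitoring plots; test_create_monitoring_plots at the end of this class
    # produces the same pair directly from create_monitoring_plots.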

    def test_save_model(self, init_without_run, path, caplog):
        caplog.set_level(logging.DEBUG)
        model_name = "TestExperiment_my_model.h5"
        assert model_name not in os.listdir(path)
        init_without_run.save_model()
        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex(f"save best model to {os.path.join(path, model_name)}"))
        assert model_name in os.listdir(path)

    def test_load_best_model_no_weights(self, init_without_run, caplog):
        caplog.set_level(logging.DEBUG)
        init_without_run.load_best_model("notExisting")
        assert caplog.record_tuples[0] == ("root", 10, PyTestRegex("load best model: notExisting"))
        assert caplog.record_tuples[1] == ("root", 20, PyTestRegex("no weights to reload..."))
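
    # The numeric levels in caplog.record_tuples are the stdlib logging levels
    # (10 == logging.DEBUG, 20 == logging.INFO); PyTestRegex from src.helpers
    # presumably implements an equality-by-regex match against the captured
    # message.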

    def test_save_callbacks_history_created(self, init_without_run, history, path):
        init_without_run.save_callbacks(history)
        assert "history.json" in os.listdir(path)

    def test_save_callbacks_lr_created(self, init_with_lr, history, path):
        init_with_lr.save_callbacks(history)
        assert "history_lr.json" in os.listdir(path)

    def test_save_callbacks_inspect_history(self, init_without_run, history, path):
        init_without_run.save_callbacks(history)
        with open(os.path.join(path, "history.json")) as jfile:
            hist = json.load(jfile)
            assert hist == history.history

    def test_save_callbacks_inspect_lr(self, init_with_lr, history, path):
        init_with_lr.save_callbacks(history)
        with open(os.path.join(path, "history_lr.json")) as jfile:
            lr = json.load(jfile)
            assert lr == init_with_lr.lr_sc.lr

    def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
        init_without_run.create_monitoring_plots(history, learning_rate)
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2