import copy
import glob
import json
import logging
import os
import shutil
import time
from typing import Callable

import mock
import pytest
import tensorflow.keras as keras
from tensorflow.keras.callbacks import History

from mlair.data_handler import DataCollection, KerasIterator, DefaultDataHandler
from mlair.helpers import PyTestRegex
from mlair.model_modules.fully_connected_networks import FCN
from mlair.model_modules.flatten import flatten_tail
from mlair.model_modules.inception_model import InceptionModelBase
from mlair.model_modules.keras_extensions import LearningRateDecay, HistoryAdvanced, CallbackHandler, EpoTimingCallback
from mlair.run_modules.run_environment import RunEnvironment
from mlair.run_modules.training import Training


def my_test_model(activation, window_history_size, channels, output_size, dropout_rate, add_minor_branch=False):
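    """Build a small Keras test model with an inception block and flatten-tail output(s).

    If ``add_minor_branch`` is True, an additional 'Minor_1' output branch with four
    neurons is returned alongside the 'Main' branch, so the model has two outputs.
    """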
    inception_model = InceptionModelBase()
    conv_settings_dict1 = {
        'tower_1': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (3, 1), 'activation': activation},
        'tower_2': {'reduction_filter': 8, 'tower_filter': 8 * 2, 'tower_kernel': (5, 1), 'activation': activation}, }
    pool_settings_dict1 = {'pool_kernel': (3, 1), 'tower_filter': 8 * 2, 'activation': activation}
    X_input = keras.layers.Input(shape=(window_history_size + 1, 1, channels))
    X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1)
    if add_minor_branch:
        out = [flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=4,
                            output_activation='linear', reduction_filter=64,
                            name='Minor_1', dropout_rate=dropout_rate,
                            )]
    else:
        out = []
    X_in = keras.layers.Dropout(dropout_rate)(X_in)
    out.append(flatten_tail(X_in, inner_neurons=64, activation=activation, output_neurons=output_size,
                            output_activation='linear', reduction_filter=64,
                            name='Main', dropout_rate=dropout_rate,
                            ))
    return keras.Model(inputs=X_input, outputs=out)


class TestTraining:
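    """Tests for the Training run module."""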

    @pytest.fixture
    def init_without_run(self, path: str, model: keras.Model, callbacks: CallbackHandler, model_path, batch_path):
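        """Create a Training object without executing Training.__init__.

        Only the RunEnvironment base class is initialised, the data collections are
        replaced by mocks, and all required paths and data store entries are set up.
        The experiment directory is removed again after the test.
        """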
        obj = object.__new__(Training)
        super(Training, obj).__init__()
        obj.model = model
        obj.train_set = None
        obj.val_set = None
        obj.test_set = None
        obj.batch_size = 256
        obj.epochs = 2
        clbk, hist, lr = callbacks
        obj.callbacks = clbk
        obj.lr_sc = lr
        obj.hist = hist
        obj.experiment_name = "TestExperiment"
        obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_train_gen"), "general.train")
        obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_val_gen"), "general.val")
        obj.data_store.set("data_collection", mock.MagicMock(return_value="mock_test_gen"), "general.test")
        if not os.path.exists(path):
            os.makedirs(path)
        obj.data_store.set("experiment_path", path, "general")
        os.makedirs(batch_path)
        obj.data_store.set("batch_path", batch_path, "general")
        os.makedirs(model_path)
        obj.data_store.set("model_path", model_path, "general")
        obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
        obj.data_store.set("experiment_name", "TestExperiment", "general")

        path_plot = os.path.join(path, "plots")
        os.makedirs(path_plot)
        obj.data_store.set("plot_path", path_plot, "general")
        obj._train_model = True
        obj._create_new_model = False
        try:
            yield obj
        finally:
            if os.path.exists(path):
                shutil.rmtree(path)
            try:
                RunEnvironment().__del__()
            except AssertionError:
                pass

    @pytest.fixture
    def learning_rate(self):
        lr = LearningRateDecay()
        lr.lr = {"lr": [0.01, 0.0094]}
        return lr

    @pytest.fixture
    def history(self):
        h = History()
        h.epoch = [0, 1]
        h.history = {'val_loss': [0.5586272982587484, 0.45712877659670287],
                     'val_mean_squared_error': [0.5586272982587484, 0.45712877659670287],
                     'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
                     'loss': [0.6795708956961347, 0.45963566494176616],
                     'mean_squared_error': [0.6795708956961347, 0.45963566494176616],
                     'mean_absolute_error': [0.6523177288928538, 0.5363963260296364]}
        h.model = mock.MagicMock()
        return h

    @pytest.fixture
    def epo_timing(self):
        epo_timing = EpoTimingCallback()
        epo_timing.epoch = [0, 1]
        epo_timing.epo_timing = {"epo_timing": [0.1, 0.2]}
        return epo_timing

    @pytest.fixture
    def path(self):
        return os.path.join(os.path.dirname(__file__), "TestExperiment")

    @pytest.fixture
    def model_path(self, path):
        return os.path.join(path, "model")

    @pytest.fixture
    def batch_path(self, path):
        return os.path.join(path, "batch")

    @pytest.fixture
    def window_history_size(self):
        return 7

    @pytest.fixture
    def window_lead_time(self):
        return 2

    @pytest.fixture
    def statistics_per_var(self):
        return {'o3': 'dma8eu', 'temp': 'maximum'}

    @pytest.fixture
    def data_collection(self, path, window_history_size, window_lead_time, statistics_per_var):
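        """Build a DataCollection containing a single DefaultDataHandler for station DEBW107 (daily data)."""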
        data_prep = DefaultDataHandler.build(['DEBW107'], data_path=os.path.join(path, 'data'),
                                             experiment_path=os.path.join(path, 'exp_path'),
                                             statistics_per_var=statistics_per_var, station_type="background",
                                             network="AIRBASE", sampling="daily", target_dim="variables",
                                             target_var="o3", time_dim="datetime",
                                             window_history_size=window_history_size,
                                             window_lead_time=window_lead_time, name_affix="train")
        return DataCollection([data_prep])

    @pytest.fixture
    def model(self, window_history_size, window_lead_time, statistics_per_var):
        channels = len(statistics_per_var)
        return FCN([(window_history_size + 1, 1, channels)], [window_lead_time])

    @pytest.fixture
    def callbacks(self, path):
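        """Create a CallbackHandler with history, learning rate decay and epoch timing callbacks plus a model checkpoint."""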
        clbk = CallbackHandler()
        hist = HistoryAdvanced()
        epo_timing = EpoTimingCallback()
        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
        lr = LearningRateDecay()
        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
        clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                     save_best_only=True)
        return clbk, hist, lr

    @pytest.fixture
    def ready_to_train(self, data_collection: DataCollection, init_without_run: Training, batch_path: str):
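        """Provide a Training object whose train and val sets are KerasIterators and whose model is compiled with SGD and MAE loss."""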
        batch_size = init_without_run.batch_size
        model = init_without_run.model
        init_without_run.train_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="train")
        init_without_run.val_set = KerasIterator(data_collection, batch_size, batch_path, model=model, name="val")
        init_without_run.model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
        return init_without_run

    @pytest.fixture
    def ready_to_run(self, data_collection, init_without_run):
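        """Provide a Training object with real data collections in the data store and the model compiled via its own compile options."""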
        obj = init_without_run
        obj.data_store.set("data_collection", data_collection, "general.train")
        obj.data_store.set("data_collection", data_collection, "general.val")
        obj.data_store.set("data_collection", data_collection, "general.test")
        obj.model.compile(**obj.model.compile_options)
        keras.utils.get_custom_objects().update(obj.model.custom_objects)
        return obj

    @pytest.fixture
    def ready_to_init(self, data_collection, model, callbacks, path, model_path, batch_path):
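        """Prepare a RunEnvironment whose data store holds a compiled model, callbacks, data collections and all
        paths and parameters needed by Training.__init__; the experiment directory is removed afterwards."""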
        if not os.path.exists(path):
            os.makedirs(path)
        os.makedirs(model_path)
        obj = RunEnvironment()
        obj.data_store.set("data_collection", data_collection, "general.train")
        obj.data_store.set("data_collection", data_collection, "general.val")
        obj.data_store.set("data_collection", data_collection, "general.test")
        model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.mean_absolute_error)
        obj.data_store.set("model", model, "general.model")
        obj.data_store.set("model_path", model_path, "general")
        obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
        obj.data_store.set("batch_size", 256, "general")
        obj.data_store.set("epochs", 2, "general")
        clbk, hist, lr = callbacks
        obj.data_store.set("callbacks", clbk, "general.model")
        obj.data_store.set("lr_decay", lr, "general.model")
        obj.data_store.set("hist", hist, "general.model")
        obj.data_store.set("experiment_name", "TestExperiment", "general")
        obj.data_store.set("experiment_path", path, "general")
        obj.data_store.set("train_model", True, "general")
        obj.data_store.set("create_new_model", True, "general")
        os.makedirs(batch_path)
        obj.data_store.set("batch_path", batch_path, "general")
        path_plot = os.path.join(path, "plots")
        os.makedirs(path_plot)
        obj.data_store.set("plot_path", path_plot, "general")
        yield obj
        if os.path.exists(path):
            shutil.rmtree(path)

    @staticmethod
    def create_training_obj(epochs, path, data_collection, batch_path, model_path,
                            statistics_per_var, window_history_size, window_lead_time) -> Training:
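        """Build a fully initialised Training object with a compiled FCN model, fresh callbacks and data store
        entries pointing at the given paths; used by the resume-training test."""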

        channels = len(statistics_per_var)
        model = FCN([(window_history_size + 1, 1, channels)], [window_lead_time])

        obj = object.__new__(Training)
        super(Training, obj).__init__()
        obj.model = model
        obj.train_set = None
        obj.val_set = None
        obj.test_set = None
        obj.batch_size = 256
        obj.epochs = epochs

        clbk = CallbackHandler()
        hist = HistoryAdvanced()
        epo_timing = EpoTimingCallback()
        clbk.add_callback(hist, os.path.join(path, "hist_checkpoint.pickle"), "hist")
        lr = LearningRateDecay()
        clbk.add_callback(lr, os.path.join(path, "lr_checkpoint.pickle"), "lr")
        clbk.add_callback(epo_timing, os.path.join(path, "epo_timing.pickle"), "epo_timing")
        clbk.create_model_checkpoint(filepath=os.path.join(path, "model_checkpoint"), monitor='val_loss',
                                     save_best_only=True)
        obj.callbacks = clbk
        obj.lr_sc = lr
        obj.hist = hist
        obj.experiment_name = "TestExperiment"
        obj.data_store.set("data_collection", data_collection, "general.train")
        obj.data_store.set("data_collection", data_collection, "general.val")
        obj.data_store.set("data_collection", data_collection, "general.test")
        if not os.path.exists(path):
            os.makedirs(path)
        obj.data_store.set("experiment_path", path, "general")
        os.makedirs(batch_path, exist_ok=True)
        obj.data_store.set("batch_path", batch_path, "general")
        os.makedirs(model_path, exist_ok=True)
        obj.data_store.set("model_path", model_path, "general")
        obj.data_store.set("model_name", os.path.join(model_path, "test_model.h5"), "general.model")
        obj.data_store.set("experiment_name", "TestExperiment", "general")

        path_plot = os.path.join(path, "plots")
        os.makedirs(path_plot, exist_ok=True)
        obj.data_store.set("plot_path", path_plot, "general")
        obj._train_model = True
        obj._create_new_model = False

        obj.model.compile(**obj.model.compile_options)
        return obj

    def test_init(self, ready_to_init):
        assert isinstance(Training(), Training)  # just test that nothing fails

    def test_no_training(self, ready_to_init, caplog):
        caplog.set_level(logging.INFO)
        ready_to_init.data_store.set("train_model", False, "general")
        Training()
        message = "No training has started, because train_model parameter was false."
        assert caplog.record_tuples[-2] == ("root", 20, message)

    def test_run(self, ready_to_run):
        assert ready_to_run._run() is None  # just test that nothing fails

    def test_make_predict_function(self, init_without_run):
        assert hasattr(init_without_run.model, "predict_function") is True
        assert init_without_run.model.predict_function is None
        init_without_run.make_predict_function()
        assert isinstance(init_without_run.model.predict_function, Callable)

    def test_set_gen(self, init_without_run):
        assert init_without_run.train_set is None
        init_without_run._set_gen("train")
        assert isinstance(init_without_run.train_set, KerasIterator)
        assert init_without_run.train_set._collection.return_value == "mock_train_gen"

    def test_set_generators(self, init_without_run):
        sets = ["train", "val"]
        assert all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
        init_without_run.set_generators()
        assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets])
        assert all(
            [getattr(init_without_run, f"{obj}_set")._collection.return_value == f"mock_{obj}_gen" for obj in sets])

    def test_train(self, ready_to_train, path):
        assert ready_to_train.model.history is None
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
        ready_to_train.train()
        assert sorted(list(ready_to_train.model.history.history.keys())) == ["loss", "val_loss"]
        assert ready_to_train.model.history.epoch == [0, 1]
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2

    def test_save_model(self, init_without_run, model_path, caplog):
        caplog.set_level(logging.DEBUG)
        model_name = "test_model.h5"
        assert model_name not in os.listdir(model_path)
        init_without_run.save_model()
        message = PyTestRegex(f"save model to {os.path.join(model_path, model_name)}")
        assert caplog.record_tuples[1] == ("root", 10, message)
        assert model_name in os.listdir(model_path)

    def test_save_callbacks_history_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
        assert "history.json" in os.listdir(model_path)

    def test_save_callbacks_lr_created(self, init_without_run, history, learning_rate, epo_timing, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
        assert "history_lr.json" in os.listdir(model_path)

    def test_save_callbacks_inspect_history(self, init_without_run, history, learning_rate, epo_timing, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
        with open(os.path.join(model_path, "history.json")) as jfile:
            hist = json.load(jfile)
            assert hist == history.history

    def test_save_callbacks_inspect_lr(self, init_without_run, history, learning_rate, epo_timing, model_path):
        init_without_run.save_callbacks_as_json(history, learning_rate, epo_timing)
        with open(os.path.join(model_path, "history_lr.json")) as jfile:
            lr = json.load(jfile)
            assert lr == learning_rate.lr

    def test_create_monitoring_plots(self, init_without_run, learning_rate, history, path):
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 0
        history.model.output_names = mock.MagicMock(return_value=["Main"])
        history.model.metrics_names = mock.MagicMock(return_value=["loss", "mean_squared_error"])
        init_without_run.create_monitoring_plots(history, learning_rate, epoch_best=1)
        assert len(glob.glob(os.path.join(path, "plots", "TestExperiment_history_*.pdf"))) == 2

    def test_resume_training1(self, path: str, model_path, batch_path, data_collection, statistics_per_var,
                              window_history_size, window_lead_time):

        obj_1st = self.create_training_obj(4, path, data_collection, batch_path, model_path, statistics_per_var,
                                           window_history_size, window_lead_time)
        keras.utils.get_custom_objects().update(obj_1st.model.custom_objects)
        assert obj_1st._run() is None
        obj_2nd = self.create_training_obj(8, path, data_collection, batch_path, model_path, statistics_per_var,
                                           window_history_size, window_lead_time)
        assert obj_2nd._run() is None