import os
import numpy as np
import shutil

import pytest

from mlair.data_handler import KerasIterator
from mlair.data_handler import DataCollection
from mlair.helpers.datastore import EmptyScope
from mlair.model_modules.keras_extensions import CallbackHandler
from mlair.model_modules.fully_connected_networks import FCN_64_32_16
from mlair.model_modules import AbstractModelClass
from mlair.run_modules.model_setup import ModelSetup
from mlair.run_modules.run_environment import RunEnvironment


class TestModelSetup:
    """Tests for ``ModelSetup`` built on a hand-assembled instance and a small dummy data collection."""

    @pytest.fixture
    def setup(self):
        """Yield a ``ModelSetup`` created without running ``ModelSetup.__init__``.

        Only the base-class initializer is run, then the attributes and data-store entries the tests need
        are set by hand. On teardown a fresh ``RunEnvironment`` is created and its ``__del__`` is invoked.
        """
        obj = object.__new__(ModelSetup)  # bypass ModelSetup.__init__ (it would run the whole stage)
        super(ModelSetup, obj).__init__()
        obj.scope = "general.model"
        obj.model = None
        obj.callbacks_name = "placeholder_%s_str.pickle"
        obj.data_store.set("model_class", FCN_64_32_16)
        obj.data_store.set("lr_decay", "dummy_str", "general.model")
        obj.data_store.set("hist", "dummy_str", "general.model")
        obj.data_store.set("epochs", 2)
        obj.model_name = "%s.h5"
        yield obj
        RunEnvironment().__del__()

    @pytest.fixture
    def path(self):
        """Yield a clean ``testdata`` directory path next to this file and remove it afterwards."""
        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
        # ignore_errors=True also covers the case where the directory does not exist yet
        shutil.rmtree(p, ignore_errors=True)
        yield p
        shutil.rmtree(p, ignore_errors=True)

    @pytest.fixture
    def keras_iterator(self, path):
        """Return a ``DataCollection`` of three ``DummyData`` items (50, 51 and 52 samples).

        A ``KerasIterator`` with batch size 25 is constructed over the collection using ``path``. Its
        instance is discarded (presumably it is built for its side effects on ``path`` — TODO confirm).
        Note that the fixture returns the data collection, not the iterator.
        """
        coll = [DummyData(50 + i) for i in range(3)]
        data_coll = DataCollection(collection=coll)
        KerasIterator(data_coll, 25, path)
        return data_coll

    @pytest.fixture
    def setup_with_gen(self, setup, keras_iterator):
        """Extend ``setup`` with the train data collection and input/output shapes taken from its first item."""
        setup.data_store.set("data_collection", keras_iterator, "train")
        input_shape = [keras_iterator[0].get_X()[0].shape[1:]]
        setup.data_store.set("input_shape", input_shape, "model")
        output_shape = [keras_iterator[0].get_Y()[0].shape[1:]]
        setup.data_store.set("output_shape", output_shape, "model")
        yield setup
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_gen_tiny(self, setup, keras_iterator):
        """Extend ``setup`` with the train data collection only (no shapes stored)."""
        setup.data_store.set("data_collection", keras_iterator, "train")
        yield setup
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_model(self, setup):
        """Extend ``setup`` with a bare ``AbstractModelClass`` carrying an extra ``test_param`` attribute."""
        setup.model = AbstractModelClass(input_shape=(12, 1), output_shape=2)
        setup.model.test_param = "42"
        yield setup
        RunEnvironment().__del__()

    @staticmethod
    def current_scope_as_set(model_cls):
        """Return the names stored directly in ``model_cls.scope`` (parent scopes excluded) as a set."""
        return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))

    def test_set_callbacks(self, setup):
        assert "general.model" not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        assert "general.model" in setup.data_store.search_name("callbacks")
        callbacks = setup.data_store.get("callbacks", "general.model")
        assert len(callbacks.get_callbacks()) == 3

    def test_set_callbacks_no_lr_decay(self, setup):
        setup.data_store.set("lr_decay", None, "general.model")
        assert "general.model" not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
        # with lr_decay set to None, one callback fewer than in test_set_callbacks is expected
        assert len(callbacks.get_callbacks()) == 2
        with pytest.raises(IndexError):
            callbacks.get_callback_by_name("lr_decay")

    def test_get_model_settings(self, setup_with_model):
        setup_with_model.scope = "model_test"
        with pytest.raises(EmptyScope):
            self.current_scope_as_set(setup_with_model)  # will fail because scope is not created
        setup_with_model.get_model_settings()  # this saves now the parameter test_param into scope
        assert {"test_param", "model_name"} <= self.current_scope_as_set(setup_with_model)

    def test_build_model(self, setup_with_gen):
        assert setup_with_gen.model is None
        setup_with_gen.build_model()
        assert isinstance(setup_with_gen.model, AbstractModelClass)
        expected = {"lr_decay", "model_name", "optimizer", "activation", "input_shape", "output_shape"}
        assert expected <= self.current_scope_as_set(setup_with_gen)

    def test_set_shapes(self, setup_with_gen_tiny):
        assert len(setup_with_gen_tiny.data_store.search_name("input_shape")) == 0
        assert len(setup_with_gen_tiny.data_store.search_name("output_shape")) == 0
        setup_with_gen_tiny._set_shapes()
        # expected shapes are DummyData's per-array shapes with the sample axis dropped
        assert setup_with_gen_tiny.data_store.get("input_shape", setup_with_gen_tiny.scope) == [(14, 1, 5), (10, 1, 2),
                                                                                                 (1, 1, 2)]
        assert setup_with_gen_tiny.data_store.get("output_shape", setup_with_gen_tiny.scope) == [(5,), (3,)]

    # The following tests are unimplemented placeholders.
    def test_load_weights(self):
        pass

    def test_compile_model(self):
        pass

    def test_run(self):
        pass

    def test_init(self):
        pass


class DummyData:
    """Minimal data-handler stand-in that returns freshly drawn random integer arrays on every call."""

    def __init__(self, number_of_samples=None):
        """Create a dummy data source.

        :param number_of_samples: size of the leading (sample) axis of every returned array. If omitted,
            a random value in ``[100, 150)`` is drawn for this instance.
        """
        # Drawn here rather than in the default argument: a default expression is evaluated only once at
        # definition time, which would give every default-constructed instance the same "random" size.
        if number_of_samples is None:
            number_of_samples = np.random.randint(100, 150)
        self.number_of_samples = number_of_samples

    def get_X(self, upsampling=False, as_numpy=True):
        """Return three random input arrays.

        ``upsampling`` and ``as_numpy`` are accepted but ignored.

        :return: list of arrays with shapes (n, 14, 1, 5) in [0, 9], (n, 10, 1, 2) in [21, 29] and
            (n, 1, 1, 2) in [-5, -1], where n is ``number_of_samples``.
        """
        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 1, 5))  # samples, window, 1, variables
        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 1, 2))  # samples, window, 1, variables
        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 1, 2))  # samples, window, 1, variables
        return [X1, X2, X3]

    def get_Y(self, upsampling=False, as_numpy=True):
        """Return two random target arrays.

        ``upsampling`` and ``as_numpy`` are accepted but ignored.

        :return: list of arrays with shapes (n, 5) in [0, 9] and (n, 3) in [21, 29], where n is
            ``number_of_samples``.
        """
        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5))  # samples, window
        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 3))  # samples, window
        return [Y1, Y2]