import os

import pytest

from src.data_handling.data_generator import DataGenerator
from src.helpers.datastore import EmptyScope
from src.model_modules.keras_extensions import CallbackHandler
from src.model_modules.model_class import AbstractModelClass
from src.run_modules.model_setup import ModelSetup
from src.run_modules.run_environment import RunEnvironment


class TestModelSetup:
    """Unit tests for ``ModelSetup``.

    The fixtures create a ``ModelSetup`` via ``object.__new__`` so its own
    ``__init__`` never runs; each test then drives a single setup step against
    a hand-prepared data store.
    """

    @pytest.fixture
    def setup(self):
        """Yield a bare ``ModelSetup`` bound to the ``general.model`` scope.

        Only the base-class initializer is executed. Dummy ``lr_decay`` and
        ``hist`` entries are placed in ``general.model``; the callback tests
        below rely on ``lr_decay`` being present there.
        """
        obj = object.__new__(ModelSetup)
        super(ModelSetup, obj).__init__()
        obj.scope = "general.model"
        obj.model = None
        obj.callbacks_name = "placeholder_%s_str.pickle"
        obj.data_store.set("lr_decay", "dummy_str", "general.model")
        obj.data_store.set("hist", "dummy_str", "general.model")
        obj.model_name = "%s.h5"
        yield obj
        # NOTE(review): presumably resets the shared run environment between
        # tests — confirm against RunEnvironment.__del__.
        RunEnvironment().__del__()

    @pytest.fixture
    def gen(self):
        """Return a ``DataGenerator`` reading from the ``data`` directory beside this file.

        Configured for 'AIRBASE' / 'DEBW107' with variables ``['o3', 'temp']``
        and statistics ``{'o3': 'dma8eu', 'temp': 'maximum'}``.
        """
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def setup_with_gen(self, setup, gen):
        """Extend ``setup`` with a training generator plus window and channel settings.

        The window sizes are copied from the generator; ``channels`` is fixed at 2.
        """
        setup.data_store.set("generator", gen, "general.train")
        setup.data_store.set("window_history_size", gen.window_history_size, "general")
        setup.data_store.set("window_lead_time", gen.window_lead_time, "general")
        setup.data_store.set("channels", 2, "general")
        yield setup
        # NOTE(review): the ``setup`` fixture runs this same teardown afterwards,
        # so it executes twice per test — confirm that is intended.
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_gen_tiny(self, setup, gen):
        """Extend ``setup`` with only the training generator (no derived settings)."""
        setup.data_store.set("generator", gen, "general.train")
        yield setup
        # NOTE(review): duplicated by the ``setup`` fixture's own teardown.
        RunEnvironment().__del__()

    @pytest.fixture
    def setup_with_model(self, setup):
        """Extend ``setup`` with an ``AbstractModelClass`` (epochs=2, batch_size=256)."""
        setup.model = AbstractModelClass()
        setup.model.epochs = 2
        setup.model.batch_size = 256
        yield setup
        # NOTE(review): duplicated by the ``setup`` fixture's own teardown.
        RunEnvironment().__del__()

    @staticmethod
    def current_scope_as_set(model_cls):
        """Return the names stored directly in ``model_cls.scope`` (parent scopes excluded)."""
        return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True))

    def test_set_callbacks(self, setup):
        """With ``lr_decay`` configured, ``_set_callbacks`` stores three callbacks."""
        assert "general.model" not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        assert "general.model" in setup.data_store.search_name("callbacks")
        callbacks = setup.data_store.get("callbacks", "general.model")
        assert len(callbacks.get_callbacks()) == 3

    def test_set_callbacks_no_lr_decay(self, setup):
        """Without ``lr_decay``, only two callbacks are stored and none is named 'lr_decay'."""
        setup.data_store.set("lr_decay", None, "general.model")
        assert "general.model" not in setup.data_store.search_name("callbacks")
        setup.checkpoint_name = "TestName"
        setup._set_callbacks()
        callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model")
        assert len(callbacks.get_callbacks()) == 2
        with pytest.raises(IndexError):
            callbacks.get_callback_by_name("lr_decay")

    def test_get_model_settings(self, setup_with_model):
        """``get_model_settings`` creates the scope and stores ``epochs`` and ``batch_size``."""
        setup_with_model.scope = "model_test"
        with pytest.raises(EmptyScope):
            # 'model_test' has not been created yet, so searching it raises
            self.current_scope_as_set(setup_with_model)
        setup_with_model.get_model_settings()  # saves epochs and batch_size into the scope
        assert {"epochs", "batch_size"} <= self.current_scope_as_set(setup_with_model)

    def test_build_model(self, setup_with_gen):
        """``build_model`` creates an ``AbstractModelClass`` and stores its settings in scope."""
        assert setup_with_gen.model is None
        setup_with_gen.build_model()
        assert isinstance(setup_with_gen.model, AbstractModelClass)
        expected = {"window_history_size", "window_lead_time", "channels", "dropout_rate", "regularizer", "initial_lr",
                    "optimizer", "epochs", "batch_size", "activation"}
        assert expected <= self.current_scope_as_set(setup_with_gen)

    def test_set_channels(self, setup_with_gen_tiny):
        """``_set_channels`` derives ``channels`` (2 for this generator) into the current scope."""
        assert len(setup_with_gen_tiny.data_store.search_name("channels")) == 0
        setup_with_gen_tiny._set_channels()
        assert setup_with_gen_tiny.data_store.get("channels", setup_with_gen_tiny.scope) == 2

    # The placeholders below are skipped rather than left as bare ``pass`` so
    # they show up as missing coverage instead of silently passing.

    @pytest.mark.skip(reason="not implemented yet")
    def test_load_weights(self):
        """Placeholder: not implemented yet."""

    @pytest.mark.skip(reason="not implemented yet")
    def test_compile_model(self):
        """Placeholder: not implemented yet."""

    @pytest.mark.skip(reason="not implemented yet")
    def test_run(self):
        """Placeholder: not implemented yet."""

    @pytest.mark.skip(reason="not implemented yet")
    def test_init(self):
        """Placeholder: not implemented yet."""