import logging
import pytest
import os
import keras

from src.modules.model_setup import ModelSetup
from src.modules.run_environment import RunEnvironment
from src.data_generator import DataGenerator


class TestModelSetup:
    """Tests for ``ModelSetup`` that drive its individual steps on a bare instance."""

    @pytest.fixture
    def setup(self):
        """Yield a ``ModelSetup`` whose own ``__init__`` has not been run.

        ``object.__new__`` creates the instance without calling ``ModelSetup.__init__``.
        Only the parent-class initialiser is then invoked explicitly. ``scope`` and
        ``model`` are set by hand so individual methods can be exercised one at a time.
        """
        obj = object.__new__(ModelSetup)
        # Run only the base-class initialiser (the class after ModelSetup in the MRO).
        super(ModelSetup, obj).__init__()
        obj.scope = "general.modeltest"
        obj.model = None
        yield obj
        # NOTE(review): presumably resets the data store shared between tests -- confirm in RunEnvironment.
        RunEnvironment().__del__()

    @pytest.fixture
    def gen(self):
        """Return a ``DataGenerator`` over the test ``data`` directory next to this file."""
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def setup_with_gen(self, setup, gen):
        """Return ``setup`` with the generator and its window sizes put into the data store.

        The generator is stored under scope ``"general.train"``; ``window_history_size``
        and ``window_lead_time`` under ``"general"``. No teardown is needed here: this
        fixture depends on ``setup``, whose own teardown runs after this fixture's.
        """
        setup.data_store.put("generator", gen, "general.train")
        setup.data_store.put("window_history_size", gen.window_history_size, "general")
        setup.data_store.put("window_lead_time", gen.window_lead_time, "general")
        return setup

    def test_set_checkpoint(self, setup):
        """``_set_checkpoint`` stores a ``checkpoint`` entry in the instance's scope."""
        assert "general.modeltest" not in setup.data_store.search_name("checkpoint")
        setup.model_name = "TestName"
        setup._set_checkpoint()
        assert "general.modeltest" in setup.data_store.search_name("checkpoint")

    def test_my_model_settings(self, setup_with_gen):
        """``my_model_settings`` puts at least the expected keys into the instance's own scope."""
        setup_with_gen.my_model_settings()
        expected = {"channels", "dropout_rate", "regularizer", "initial_lr", "optimizer", "lr_decay", "epochs",
                    "batch_size", "activation", "loss"}
        assert expected <= set(setup_with_gen.data_store.search_scope(setup_with_gen.scope, current_scope_only=True))

    def test_build_model(self, setup_with_gen):
        """``build_model`` replaces the initial ``None`` model with a ``keras.Model``."""
        setup_with_gen.my_model_settings()
        assert setup_with_gen.model is None
        setup_with_gen.build_model()
        assert isinstance(setup_with_gen.model, keras.Model)

    def test_load_weights(self):
        """Placeholder: reported as skipped so it does not masquerade as a passing test."""
        pytest.skip("test_load_weights is not implemented yet")

    def test_compile_model(self):
        """Placeholder: reported as skipped so it does not masquerade as a passing test."""
        pytest.skip("test_compile_model is not implemented yet")