import logging

import pytest

from src.data_handling.data_generator import DataGenerator
from src.datastore import NameNotFoundInScope
from src.helpers import PyTestRegex
from src.run_modules.experiment_setup import ExperimentSetup
from src.run_modules.pre_processing import PreProcessing, DEFAULT_ARGS_LIST, DEFAULT_KWARGS_LIST
from src.run_modules.run_environment import RunEnvironment


class TestPreProcessing:
    """Tests for the ``PreProcessing`` run module.

    Most tests work on a ``PreProcessing`` object that is created with ``object.__new__`` so that
    ``PreProcessing.__init__`` itself is not executed; only the initialiser found via
    ``super(PreProcessing, ...)`` is run. Every fixture and test that sets up such state finishes with
    ``RunEnvironment().__del__()``.
    NOTE(review): that call presumably resets state shared between run modules (e.g. the data store) -
    confirm against ``RunEnvironment``.
    """

    @pytest.fixture
    def obj_super_init(self):
        """Yield a ``PreProcessing`` object with some dummy entries in its data store.

        ``PreProcessing.__init__`` is bypassed; only the parent initialiser is called. The clean-up after
        ``yield`` is executed by pytest once the requesting test has finished.
        """
        obj = object.__new__(PreProcessing)
        super(PreProcessing, obj).__init__()
        # dummy entries on several nested scopes; "NAME1" is set on two different scopes
        obj.data_store.set("NAME1", 1, "general")
        obj.data_store.set("NAME2", 2, "general")
        obj.data_store.set("NAME3", 3, "general")
        obj.data_store.set("NAME1", 10, "general.sub")
        obj.data_store.set("NAME4", 4, "general.sub.sub")
        yield obj
        RunEnvironment().__del__()

    @pytest.fixture
    def obj_with_exp_setup(self):
        """Yield a ``PreProcessing`` object after an ``ExperimentSetup`` with six stations has been run.

        As in ``obj_super_init``, ``PreProcessing.__init__`` is bypassed, so none of the pre-processing
        steps has been executed yet when the object is handed to the test.
        """
        ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
                        var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background")
        pre = object.__new__(PreProcessing)
        super(PreProcessing, pre).__init__()
        yield pre
        RunEnvironment().__del__()

    def test_init(self, caplog):
        """Run the full ``PreProcessing`` as context manager and check first and last log records."""
        ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                        var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
        caplog.set_level(logging.INFO)
        try:
            with PreProcessing():
                assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
                assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
                assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
                                                                            r'station\(s\). Found 5/5 valid stations.'))
        finally:
            # always clean up, also if an assertion above fails; otherwise the state created by
            # ExperimentSetup would stay alive for the following tests (the fixtures do the same after yield)
            RunEnvironment().__del__()

    def test_run(self, obj_with_exp_setup):
        """Check that ``_run`` returns None and stores a generator for train, val, train_val and test."""
        assert obj_with_exp_setup.data_store.search_name("generator") == []
        assert obj_with_exp_setup._run() is None
        assert obj_with_exp_setup.data_store.search_name("generator") == sorted(["general.train", "general.val",
                                                                                 "general.train_val", "general.test"])

    def test_split_train_val_test(self, obj_with_exp_setup):
        """Check the entries written to the data store by ``split_train_val_test``."""
        assert obj_with_exp_setup.data_store.search_name("generator") == []
        obj_with_exp_setup.split_train_val_test()
        data_store = obj_with_exp_setup.data_store
        expected_params = ["generator", "start", "end", "stations", "permute_data"]
        assert data_store.search_scope("general.train") == sorted(expected_params)
        assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test",
                                                              "general.train_val"])

    def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup):
        """Check ``create_set_split`` with ``use_all_stations_on_all_data_sets`` set to False.

        Only the two stations selected by ``slice(0, 2)`` are expected in the log record and in the
        ``general.awesome`` scope.
        """
        caplog.set_level(logging.DEBUG)
        obj_with_exp_setup.data_store.set("use_all_stations_on_all_data_sets", False, "general.awesome")
        obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
        assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']")
        data_store = obj_with_exp_setup.data_store
        assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
        # the generator must only exist on the sub scope and not on the general scope
        with pytest.raises(NameNotFoundInScope):
            data_store.get("generator", "general")
        assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBY081"]

    def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
        """Check ``create_set_split`` without touching ``use_all_stations_on_all_data_sets``.

        Although ``slice(0, 2)`` is passed, the log record is expected to list all six stations, and five of
        them (all but 'DEBW001') are expected in the ``general.awesome`` scope afterwards.
        """
        caplog.set_level(logging.DEBUG)
        obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
        assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', "
                                                       "'DEBW076', 'DEBW087', 'DEBW001']")
        data_store = obj_with_exp_setup.data_store
        assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
        # the generator must only exist on the sub scope and not on the general scope
        with pytest.raises(NameNotFoundInScope):
            data_store.get("generator", "general")
        assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']

    def test_check_valid_stations(self, caplog, obj_with_exp_setup):
        """Check that ``check_valid_stations`` drops the last of the six stations and logs the result."""
        pre = obj_with_exp_setup
        caplog.set_level(logging.INFO)
        args = pre.data_store.create_args_dict(DEFAULT_ARGS_LIST)
        kwargs = pre.data_store.create_args_dict(DEFAULT_KWARGS_LIST)
        stations = pre.data_store.get("stations", "general")
        valid_stations = pre.check_valid_stations(args, kwargs, stations)
        assert len(valid_stations) < len(stations)
        assert valid_stations == stations[:-1]
        assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
                                                                    r'station\(s\). Found 5/6 valid stations.'))

    def test_split_set_indices(self, obj_super_init):
        """Check the four index objects returned by ``split_set_indices`` for 15 elements and 0.9."""
        dummy_list = list(range(0, 15))
        train, val, test, train_val = obj_super_init.split_set_indices(len(dummy_list), 0.9)
        # the returned objects are used directly as list indices; train_val must cover train and val
        assert dummy_list[train] == list(range(0, 10))
        assert dummy_list[val] == list(range(10, 13))
        assert dummy_list[test] == list(range(13, 15))
        assert dummy_list[train_val] == list(range(0, 13))