import pytest
import os
from src.data_generator import DataGenerator
import logging
import numpy as np
import xarray as xr
import datetime as dt
import pandas as pd
from operator import itemgetter


class TestDataGenerator:
    """Tests for ``DataGenerator`` built on the ``data`` directory next to this module."""

    # Directory with the test data, resolved relative to this test module so the
    # tests do not depend on the current working directory.
    DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')

    @pytest.fixture
    def gen(self):
        """Return a fresh DataGenerator for station DEBW107 with variables o3 and temp."""
        return DataGenerator(self.DATA_PATH, 'UBA', 'DEBW107', ['o3', 'temp'],
                             'datetime', 'variables', 'o3')

    def test_init(self, gen):
        """Constructor arguments are stored and the remaining attributes get their defaults."""
        assert gen.path == self.DATA_PATH
        assert gen.network == 'UBA'
        # A single station string is expected to be stored as a one-element list.
        assert gen.stations == ['DEBW107']
        assert gen.variables == ['o3', 'temp']
        assert gen.interpolate_dim == 'datetime'
        assert gen.target_dim == 'variables'
        assert gen.target_var == 'o3'
        assert gen.interpolate_method == "linear"
        assert gen.limit_nan_fill == 1
        assert gen.window_history == 7
        assert gen.window_lead_time == 4
        assert gen.transform_method == "standardise"
        assert gen.kwargs == {}
        assert gen.threshold is not None

    def test_repr(self, gen):
        """repr() lists the constructor arguments in order, followed by the kwargs."""
        path = self.DATA_PATH
        assert repr(gen).rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\
                                     f"variables=['o3', 'temp'], interpolate_dim='datetime', " \
                                     f"target_dim='variables', target_var='o3', **{{}})".rstrip()

    def test_len(self, gen):
        """len() follows the number of configured stations."""
        assert len(gen) == 1
        gen.stations = ['station1', 'station2', 'station3']
        assert len(gen) == 3

    def test_iter(self, gen):
        """iter() creates the internal ``_iterator`` counter and resets it to 0."""
        assert not hasattr(gen, '_iterator')
        iter(gen)
        assert hasattr(gen, '_iterator')
        assert gen._iterator == 0

    def test_next(self, gen):
        """Each step of iteration advances ``_iterator`` by one."""
        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
        for i, _ in enumerate(gen, start=1):
            assert i == gen._iterator

    def test_getitem(self, gen):
        """Indexing by station name yields a pair of data arrays with the windowed shapes."""
        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
        station = gen["DEBW107"]
        assert len(station) == 2
        assert station[0].Stations.data == "DEBW107"
        assert station[0].data.shape[1:] == (8, 1, 2)
        assert station[1].data.shape[-1] == gen.window_lead_time
        assert station[0].data.shape[1] == gen.window_history + 1

    def test_threshold_setup(self, gen):
        """threshold_setup() honours thr_min, thr_max and thr_number_of_steps from kwargs."""
        def res(arg, val):
            # kwargs are updated cumulatively, so each call builds on the previous ones.
            gen.kwargs[arg] = val
            return list(map(float, gen.threshold_setup()))
        # assert_array_almost_equal raises on mismatch, so no extra assert is needed.
        compare = np.testing.assert_array_almost_equal
        # The dummy '' key leaves every threshold option at its default.
        compare(res('', ''), np.linspace(0, 100, 200), decimal=3)
        compare(res('thr_min', 10), np.linspace(10, 100, 200), decimal=3)
        compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3)
        compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3)

    def test_get_key_representation(self, gen):
        """get_station_key resolves None, indices, names and one-element lists; it rejects everything else."""
        gen.stations.append("DEBW108")
        f = gen.get_station_key
        # NOTE(review): iter() appears to be needed so that f(None) can resolve
        # the current station via the iterator position — confirm in DataGenerator.
        iter(gen)
        assert f(None) == "DEBW107"
        with pytest.raises(KeyError) as e:
            f([None, None])
        assert "More than one key was given: [None, None]" in e.value.args[0]
        assert f(1) == "DEBW108"
        assert f([1]) == "DEBW108"
        with pytest.raises(KeyError) as e:
            f(3)
        assert "3 is not in range(0, 2)" in e.value.args[0]
        assert f("DEBW107") == "DEBW107"
        assert f(["DEBW108"]) == "DEBW108"
        with pytest.raises(KeyError) as e:
            f("DEBW999")
        assert "DEBW999 is not in stations" in e.value.args[0]
        with pytest.raises(KeyError) as e:
            f(6.5)
        # Previously this asserted a bare (always truthy) string; check the real message.
        assert "key has to be from Union[str, int]. Given was 6.5 (float)" in e.value.args[0]