from src.data_handling.bootstraps import BootStraps, BootStrapGenerator
from src.data_handling.data_generator import DataGenerator

import pytest
import os

import numpy as np
import xarray as xr


class TestBootstraps:
    """Tests for the bootstrap file checks and the shuffling helper of ``BootStraps``."""

    @pytest.fixture
    def path(self, tmp_path):
        """
        Provide a fresh and empty directory that serves as bootstrap path.

        A per-test temporary directory (pytest's built-in ``tmp_path``) is used instead of a folder next to this file.
        The tests below assert on the exact directory content, so they must neither see files left behind by earlier
        runs or other tests nor leave their own files behind.

        :return: path of the empty directory as string
        """
        return str(tmp_path)

    @pytest.fixture
    def boot_no_init(self, path):
        """
        Create a ``BootStraps`` object without running its own ``__init__``.

        Only the two attributes set here are prepared on the object.

        :return: bare ``BootStraps`` instance using 50 bootstraps and the temporary bootstrap path
        """
        obj = object.__new__(BootStraps)
        super(BootStraps, obj).__init__()
        obj.number_bootstraps = 50
        obj.bootstrap_path = path
        return obj

    @staticmethod
    def _touch(directory, file_name):
        """
        Create the empty file ``file_name`` inside ``directory``.

        Replaces ``os.mknod``, which is missing on Windows and may require elevated privileges on macOS. Mode ``"x"``
        keeps the behaviour of raising ``FileExistsError`` if the file is already present.
        """
        with open(os.path.join(directory, file_name), "x"):
            pass

    def test_valid_bootstrap_file(self, path, boot_no_init):
        """Check the return value of ``valid_bootstrap_file`` and the remaining files for several directory states."""
        station = "TESTSTATION"
        variables = "var1_var2_var3"
        window = 5
        # empty case
        assert len(os.listdir(path)) == 0
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (False, 50)
        # different cases, where files with a bigger range exist
        self._touch(path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (True, None)
        self._touch(path, f"{station}_{variables}_hist5_nboots100_shuffled.dat")
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (True, None)
        self._touch(path, f"{station}_{variables}_hist10_nboots50_shuffled.dat")
        # file of another station ("TESTSTATION1"), this is the only file expected to remain after the next reload
        self._touch(path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (True, None)
        # need to reload data and therefore remove the non-fitting files of this station
        assert boot_no_init.valid_bootstrap_file(station, variables, 20) == (False, 100)
        assert len(os.listdir(path)) == 1
        # reload because of expanded boot number
        self._touch(path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
        boot_no_init.number_bootstraps = 60
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (False, 60)
        assert len(os.listdir(path)) == 1
        # reload because of expanded window size, but use the maximum boot number from the file names
        self._touch(path, f"{station}_{variables}_hist5_nboots60_shuffled.dat")
        boot_no_init.number_bootstraps = 50
        assert boot_no_init.valid_bootstrap_file(station, variables, 20) == (False, 60)

    # NOTE: the method name carries a typo ("variale"); it is kept so that the test id stays unchanged.
    def test_shuffle_single_variale(self, boot_no_init):
        """Shuffled data must keep shape and value range and may only contain values of the input."""
        data = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
        res = boot_no_init.shuffle_single_variable(data, chunks=(2, 3)).compute()
        assert res.shape == data.shape
        assert res.max() == data.max()
        assert res.min() == data.min()
        assert set(np.unique(res)).issubset({1, 2, 3})

    def test_create_shuffled_data(self):
        # TODO: placeholder, no checks implemented yet
        pass


class TestBootstrapGenerator:
    """Tests for ``BootStrapGenerator`` built on top of a ``DataGenerator`` for two stations and two variables."""

    @pytest.fixture
    def data_path(self):
        """Return the path of the ``data`` folder located next to this test file."""
        return os.path.join(os.path.dirname(__file__), 'data')

    @pytest.fixture
    def orig_generator(self, data_path):
        """Return the ``DataGenerator`` that is wrapped by the bootstrap generator under test."""
        return DataGenerator(data_path, 'AIRBASE', ['DEBW107', 'DEBW013'],
                             ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014)

    @pytest.fixture
    def boot_gen(self, orig_generator, tmp_path):
        """
        Return a ``BootStrapGenerator`` with 20 boots whose bootstrap path holds one dummy file per station.

        The dummy netCDF files are written into pytest's per-test temporary directory so that they do not pollute the
        shared ``data`` folder or influence other tests that inspect a directory's content.
        """
        path = str(tmp_path)
        dummy_content = xr.DataArray([1, 2, 3], dims="dummy")
        for station in ("DEBW107", "DEBW013"):
            dummy_content.to_netcdf(os.path.join(path, f"{station}_o3_temp_shuffled.nc"))
        return BootStrapGenerator(orig_generator, 20, path)

    def test_init(self, orig_generator, data_path):
        """Attributes must be taken over from the original generator and the constructor arguments."""
        gen = BootStrapGenerator(orig_generator, 20, data_path)
        assert gen.stations == ["DEBW107", "DEBW013"]
        assert gen.variables == ["o3", "temp"]
        assert gen.number_of_boots == 20
        assert gen.bootstrap_path == data_path

    def test_len(self, boot_gen):
        """The fixture with 2 stations, 2 variables and 20 boots is expected to report a length of 80."""
        assert len(boot_gen) == 80

    def test_get_generator(self, boot_gen):
        # TODO: placeholder, no checks implemented yet
        pass

    def test_get_bootstrap_meta(self, boot_gen):
        # TODO: placeholder, no checks implemented yet
        pass

    def test_get_labels(self, boot_gen):
        # TODO: placeholder, no checks implemented yet
        pass

    def test_get_orig_prediction(self, boot_gen):
        # TODO: placeholder, no checks implemented yet
        pass

    def test_load_shuffled_data(self, boot_gen):
        """Loading must return the dummy content written by the ``boot_gen`` fixture as ``xr.DataArray``."""
        shuffled_data = boot_gen.load_shuffled_data("DEBW107", ["o3", "temp"])
        assert isinstance(shuffled_data, xr.DataArray)
        # array_equal also fails cleanly on a shape mismatch, where an element-wise `==` may collapse to a scalar
        assert np.array_equal(shuffled_data.compute().values, [1, 2, 3])

    def test_create_file_regex(self, boot_gen):
        """The regex must only match shuffled files of the station whose name contains all requested variables."""
        regex = boot_gen.create_file_regex("DEBW108", ["o3", "temp", "h2o"])
        test_list = ["DEBW108_o3_test23_test_shuffled.nc",
                     "DEBW107_o3_test23_test_shuffled.nc",
                     "DEBW108_o3_test23_test.nc",
                     "DEBW108_h2o_o3_temp_test_shuffled.nc",
                     "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"]
        assert list(filter(regex.search, test_list)) == ["DEBW108_h2o_o3_temp_test_shuffled.nc",
                                                         "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"]