"""Tests for ``BootStraps`` and ``BootStrapGenerator`` from ``src.data_handling.bootstraps``."""

import os

import numpy as np
import pytest
import xarray as xr

from src.data_handling.bootstraps import BootStraps, BootStrapGenerator
from src.data_handling.data_generator import DataGenerator


def _touch(directory, file_name):
    """Create an empty file ``file_name`` inside ``directory``.

    Used instead of ``os.mknod``, which is not available on Windows and is
    restricted for regular files on some Unix platforms.
    """
    with open(os.path.join(directory, file_name), "w"):
        pass


class TestBootstraps:
    """Tests for the file-handling and shuffling helpers of ``BootStraps``."""

    @pytest.fixture
    def path(self, tmp_path):
        """Return a fresh, empty per-test directory as a string.

        A temporary directory is used instead of the shared ``data`` folder:
        ``test_valid_bootstrap_file`` asserts exact directory contents, so it
        must not see files left behind by earlier runs or other tests.
        """
        return str(tmp_path)

    @pytest.fixture
    def boot_no_init(self, path):
        """Return a ``BootStraps`` instance built without running its own ``__init__``.

        Only the parent-class initialiser is called. The attributes needed by
        the tests are then set by hand.
        """
        obj = object.__new__(BootStraps)
        super(BootStraps, obj).__init__()
        obj.number_bootstraps = 50
        obj.bootstrap_path = path
        return obj

    def test_valid_bootstrap_file(self, path, boot_no_init):
        station = "TESTSTATION"
        variables = "var1_var2_var3"
        window = 5
        # empty case
        assert len(os.listdir(path)) == 0
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (False, 50)
        # different cases, where files with bigger range are existing
        _touch(path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (True, None)
        _touch(path, f"{station}_{variables}_hist5_nboots100_shuffled.dat")
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (True, None)
        _touch(path, f"{station}_{variables}_hist10_nboots50_shuffled.dat")
        _touch(path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (True, None)
        # need to reload data and therefore remove not fitting files for this station;
        # only the file of the other station (TESTSTATION1) is expected to remain
        assert boot_no_init.valid_bootstrap_file(station, variables, 20) == (False, 100)
        assert len(os.listdir(path)) == 1
        # reload because expanded boot number
        _touch(path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
        boot_no_init.number_bootstraps = 60
        assert boot_no_init.valid_bootstrap_file(station, variables, window) == (False, 60)
        assert len(os.listdir(path)) == 1
        # reload because of expanded window size, but use maximum boot number from file names
        _touch(path, f"{station}_{variables}_hist5_nboots60_shuffled.dat")
        boot_no_init.number_bootstraps = 50
        assert boot_no_init.valid_bootstrap_file(station, variables, 20) == (False, 60)

    def test_shuffle_single_variale(self, boot_no_init):
        data = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
        res = boot_no_init.shuffle_single_variable(data, chunks=(2, 3)).compute()
        # shuffling may reorder values but must not change shape or introduce new values
        assert res.shape == data.shape
        assert res.max() == data.max()
        assert res.min() == data.min()
        assert set(np.unique(res)).issubset({1, 2, 3})

    def test_create_shuffled_data(self):
        # TODO: not implemented yet
        pass


class TestBootstrapGenerator:
    """Tests for ``BootStrapGenerator`` built on top of a ``DataGenerator``."""

    @pytest.fixture
    def orig_generator(self):
        """Return a ``DataGenerator`` for two AIRBASE stations with o3 and temp."""
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'],
                             ['o3', 'temp'], 'datetime', 'variables', 'o3', start=2010, end=2014)

    @pytest.fixture
    def boot_gen(self, orig_generator, tmp_path):
        """Return a ``BootStrapGenerator`` whose bootstrap path holds two dummy shuffled files.

        The dummy ``*_shuffled.nc`` files are written to a per-test temporary
        directory, so they do not pile up in the shared ``data`` folder.
        """
        path = str(tmp_path)
        dummy_content = xr.DataArray([1, 2, 3], dims="dummy")
        dummy_content.to_netcdf(os.path.join(path, "DEBW107_o3_temp_shuffled.nc"))
        dummy_content.to_netcdf(os.path.join(path, "DEBW013_o3_temp_shuffled.nc"))
        return BootStrapGenerator(orig_generator, 20, path)

    def test_init(self, orig_generator):
        gen = BootStrapGenerator(orig_generator, 20, os.path.join(os.path.dirname(__file__), 'data'))
        assert gen.stations == ["DEBW107", "DEBW013"]
        assert gen.variables == ["o3", "temp"]
        assert gen.number_of_boots == 20
        assert gen.bootstrap_path == os.path.join(os.path.dirname(__file__), 'data')

    def test_len(self, boot_gen):
        # presumably 20 boots x 2 stations x 2 variables -- confirm against BootStrapGenerator.__len__
        assert len(boot_gen) == 80

    def test_get_generator(self, boot_gen):
        # TODO: not implemented yet
        pass

    def test_get_bootstrap_meta(self, boot_gen):
        # TODO: not implemented yet
        pass

    def test_get_labels(self, boot_gen):
        # TODO: not implemented yet
        pass

    def test_get_orig_prediction(self, boot_gen):
        # TODO: not implemented yet
        pass

    def test_load_shuffled_data(self, boot_gen):
        shuffled_data = boot_gen.load_shuffled_data("DEBW107", ["o3", "temp"])
        assert isinstance(shuffled_data, xr.DataArray)
        assert all(shuffled_data.compute().values == [1, 2, 3])

    def test_create_file_regex(self, boot_gen):
        regex = boot_gen.create_file_regex("DEBW108", ["o3", "temp", "h2o"])
        test_list = ["DEBW108_o3_test23_test_shuffled.nc",
                     "DEBW107_o3_test23_test_shuffled.nc",
                     "DEBW108_o3_test23_test.nc",
                     "DEBW108_h2o_o3_temp_test_shuffled.nc",
                     "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"]
        # expected matches: DEBW108 files that mention every requested variable and end in "_shuffled.nc"
        assert list(filter(regex.search, test_list)) == ["DEBW108_h2o_o3_temp_test_shuffled.nc",
                                                        "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc"]