"""Tests for the bootstrap helpers ``BootStrapGenerator``, ``CreateShuffledData`` and ``BootStraps``."""

import logging
import os
import shutil

import mock
import numpy as np
import pytest
import xarray as xr

# CreateShuffledData is patched below as an attribute of the bootstraps module, so it lives there too.
# NOTE(review): BootStrapGenerator is assumed to live in the same module -- confirm.
from mlair.data_handler.bootstraps import BootStraps, BootStrapGenerator, CreateShuffledData
# NOTE(review): `src.data_handler` is a different top-level package than `mlair`; this looks like a stale
# path left over from a rename -- TODO confirm where DataPrepJoin lives now.
from src.data_handler import DataPrepJoin

# NOTE(review): `DataGenerator` is used in this module but never imported. Its location cannot be determined
# from this file (the patch target in TestBootStraps still says `mlair.data_handling.data_generator`).
# TODO add the correct import.


@pytest.fixture
def orig_generator(data_path):
    """Data generator over two stations (DEBW107, DEBW013) with variables o3 and temp, target o3, 2010-2014."""
    return DataGenerator(data_path, ['DEBW107', 'DEBW013'], ['o3', 'temp'], 'datetime', 'variables', 'o3',
                         start=2010, end=2014, statistics_per_var={"o3": "dma8eu", "temp": "maximum"},
                         data_preparation=DataPrepJoin)


@pytest.fixture
def data_path():
    """Return the ``data`` directory next to this test module, creating it if it does not exist yet."""
    path = os.path.join(os.path.dirname(__file__), "data")
    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(path, exist_ok=True)
    return path


class TestBootStrapGenerator:
    """Tests for BootStrapGenerator.

    The generator is built from one station's transposed history and a single "shuffled" realisation
    equal to ``hist + 1``, so the shuffled o3 values are easy to predict.
    """

    @pytest.fixture
    def hist(self, orig_generator):
        """Transposed history of the first station of ``orig_generator``."""
        return orig_generator.get_data_generator(0).get_transposed_history()

    @pytest.fixture
    def boot_gen(self, hist):
        """Generator with 20 boots, variables o3/temp, shuffling o3; the shuffled data is ``hist + 1``."""
        return BootStrapGenerator(20, hist, hist.expand_dims({"boots": [0]}) + 1, ["o3", "temp"], "o3")

    def test_init(self, boot_gen, hist):
        assert boot_gen.number_of_boots == 20
        assert boot_gen.variables == ["o3", "temp"]
        assert xr.testing.assert_equal(boot_gen.history_orig, hist) is None
        # the shuffled variable (o3) is removed from the remaining history
        assert xr.testing.assert_equal(boot_gen.history, hist.sel(variables=["temp"])) is None
        assert xr.testing.assert_allclose(boot_gen.shuffled - 1,
                                          hist.sel(variables="o3").expand_dims({"boots": [0]})) is None

    def test_len(self, boot_gen):
        assert len(boot_gen) == 20

    def test_get_shuffled(self, boot_gen, hist):
        # access the name-mangled private method directly
        shuffled = boot_gen._BootStrapGenerator__get_shuffled(0)
        expected = hist.sel(variables=["o3"]).transpose("datetime", "window", "Stations", "variables") + 1
        assert xr.testing.assert_equal(shuffled, expected) is None

    def test_getitem(self, boot_gen, hist):
        first_element = boot_gen[0]
        assert xr.testing.assert_equal(first_element.sel(variables="temp"), hist.sel(variables="temp")) is None
        assert xr.testing.assert_allclose(first_element.sel(variables="o3"), hist.sel(variables="o3") + 1) is None

    def test_next(self, boot_gen, hist):
        iter_obj = iter(boot_gen)
        first_element = next(iter_obj)
        assert xr.testing.assert_equal(first_element.sel(variables="temp"), hist.sel(variables="temp")) is None
        assert xr.testing.assert_allclose(first_element.sel(variables="o3"), hist.sel(variables="o3") + 1) is None
        # only one shuffled realisation exists, so a second step must fail
        with pytest.raises(KeyError):
            next(iter_obj)


class TestCreateShuffledData:
    """Tests for CreateShuffledData: creation, validation and naming of shuffled bootstrap files."""

    @pytest.fixture
    def shuffled_data(self, orig_generator, data_path):
        return CreateShuffledData(orig_generator, 20, data_path)

    @pytest.fixture
    @mock.patch("mlair.data_handler.bootstraps.CreateShuffledData.create_shuffled_data", return_value=None)
    def shuffled_data_no_creation(self, mock_create_shuffle_data, orig_generator, data_path):
        """Instance whose ``create_shuffled_data`` is patched out while it is constructed."""
        return CreateShuffledData(orig_generator, 20, data_path)

    @pytest.fixture
    def shuffled_data_clean(self, shuffled_data_no_creation):
        """Like ``shuffled_data_no_creation`` but with an emptied bootstrap directory."""
        shutil.rmtree(shuffled_data_no_creation.bootstrap_path)
        os.makedirs(shuffled_data_no_creation.bootstrap_path)
        assert os.listdir(shuffled_data_no_creation.bootstrap_path) == []  # just to check for a clean working directory
        return shuffled_data_no_creation

    def test_init(self, shuffled_data_no_creation, data_path):
        assert isinstance(shuffled_data_no_creation.data, DataGenerator)
        assert shuffled_data_no_creation.number_of_bootstraps == 20
        assert shuffled_data_no_creation.bootstrap_path == data_path

    def test_create_shuffled_data_create_new(self, shuffled_data_clean, data_path, caplog):
        """On an empty directory, one shuffled file is created per station."""
        caplog.set_level(logging.INFO)
        shuffled_data_clean.data.data_path_tmp = data_path
        assert shuffled_data_clean.create_shuffled_data() is None
        assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")
        assert caplog.record_tuples[1] == ('root', logging.INFO, "create bootstap data for DEBW107")
        # record index 2 is not asserted (presumably an intermediate message -- TODO confirm)
        assert caplog.record_tuples[3] == ('root', logging.INFO, "create bootstap data for DEBW013")
        assert "DEBW107_o3_temp_hist7_nboots20_shuffled.nc" in os.listdir(data_path)
        assert "DEBW013_o3_temp_hist7_nboots20_shuffled.nc" in os.listdir(data_path)

    def test_create_shuffled_data_some_valid(self, shuffled_data_clean, data_path, caplog):
        """Only the station whose file no longer fits (smaller window) is regenerated."""
        shuffled_data_clean.data.data_path_tmp = data_path
        shuffled_data_clean.create_shuffled_data()
        caplog.clear()  # drop log records from the initial creation
        caplog.set_level(logging.INFO)
        os.rename(os.path.join(data_path, "DEBW013_o3_temp_hist7_nboots20_shuffled.nc"),
                  os.path.join(data_path, "DEBW013_o3_temp_hist5_nboots30_shuffled.nc"))
        shuffled_data_clean.create_shuffled_data()
        assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")
        assert caplog.record_tuples[1] == ('root', logging.INFO, "create bootstap data for DEBW013")
        assert "DEBW107_o3_temp_hist7_nboots20_shuffled.nc" in os.listdir(data_path)
        assert "DEBW013_o3_temp_hist7_nboots30_shuffled.nc" in os.listdir(data_path)
        assert "DEBW013_o3_temp_hist5_nboots30_shuffled.nc" not in os.listdir(data_path)

    def test_set_file_path(self, shuffled_data_no_creation):
        res = shuffled_data_no_creation._set_file_path("DEBWtest", "o3_temp_wind", 10, 5)
        assert "DEBWtest_o3_temp_wind_hist10_nboots5_shuffled.nc" in res
        assert shuffled_data_no_creation.bootstrap_path in res

    def test_valid_bootstrap_file_blank(self, shuffled_data_clean):
        assert shuffled_data_clean.valid_bootstrap_file("DEBWtest", "o3_temp", 10) == (False, 20)

    def test_valid_bootstrap_file_already_satisfied(self, shuffled_data_clean, data_path):
        """Existing files with a window and boot count at least as large as requested are accepted."""
        station, variables, window = "DEBWtest2", "o3_temp", 5
        os.mknod(os.path.join(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat"))
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (True, None)
        os.mknod(os.path.join(data_path, f"{station}_{variables}_hist5_nboots100_shuffled.dat"))
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (True, None)
        os.mknod(os.path.join(data_path, f"{station}_{variables}_hist10_nboots50_shuffled.dat"))
        os.mknod(os.path.join(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat"))
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (True, None)

    def test_valid_bootstrap_file_reload_data_window(self, shuffled_data_clean, data_path):
        station, variables, window = "DEBWtest2", "o3_temp", 20
        os.mknod(os.path.join(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat"))
        os.mknod(os.path.join(data_path, f"{station}_{variables}_hist5_nboots100_shuffled.dat"))
        os.mknod(os.path.join(data_path, f"{station}_{variables}_hist10_nboots50_shuffled.dat"))
        os.mknod(os.path.join(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat"))  # <- DEBWtest21
        # need to reload data and therefore remove not fitting history size in all files for this station
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (False, 100)
        assert len(os.listdir(data_path)) == 1  # keep only data from other station DEBWtest21

    def test_valid_bootstrap_file_reload_data_boots(self, shuffled_data_clean, data_path):
        station, variables, window = "DEBWtest2", "o3_temp", 5
        os.mknod(os.path.join(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat"))
        os.mknod(os.path.join(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat"))  # <- DEBWtest21
        # reload because expanded boot number
        shuffled_data_clean.number_of_bootstraps = 60
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (False, 60)
        assert len(os.listdir(data_path)) == 1

    def test_valid_bootstrap_file_reload_data_use_max_file_boot(self, shuffled_data_clean, data_path):
        station, variables, window = "DEBWtest2", "o3_temp", 20
        os.mknod(os.path.join(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat"))
        os.mknod(os.path.join(data_path, f"{station}_{variables}_hist5_nboots60_shuffled.dat"))
        os.mknod(os.path.join(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat"))  # <- DEBWtest21
        # reload because of expanded window size, but use maximum boot number from file names
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (False, 60)

    def test_shuffle(self, shuffled_data_no_creation):
        """Shuffling keeps the shape and draws only from the input's values."""
        dummy = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
        res = shuffled_data_no_creation.shuffle(dummy, chunks=(2, 3)).compute()
        assert res.shape == dummy.shape
        assert dummy.max() >= res.max()
        assert dummy.min() <= res.min()
        assert set(np.unique(res)).issubset({1, 2, 3})


class TestBootStraps:
    """Tests for BootStraps: properties, generator creation, labels, predictions and shuffled-file lookup."""

    @pytest.fixture
    def bootstrap(self, orig_generator, data_path):
        return BootStraps(orig_generator, data_path, 20)

    @pytest.fixture
    @mock.patch("mlair.data_handler.bootstraps.CreateShuffledData", return_value=None)
    def bootstrap_no_shuffling(self, mock_create_shuffle_data, orig_generator, data_path):
        """BootStraps built with CreateShuffledData patched out; ``data_path`` is removed beforehand."""
        shutil.rmtree(data_path)
        return BootStraps(orig_generator, data_path, 20)

    def test_init_no_shuffling(self, bootstrap_no_shuffling, data_path):
        assert isinstance(bootstrap_no_shuffling, BootStraps)
        assert bootstrap_no_shuffling.number_of_bootstraps == 20
        assert bootstrap_no_shuffling.bootstrap_path == data_path

    def test_init_with_shuffling(self, orig_generator, data_path, caplog):
        caplog.set_level(logging.INFO)
        BootStraps(orig_generator, data_path, 20)
        assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")

    def test_stations(self, bootstrap_no_shuffling, orig_generator):
        assert bootstrap_no_shuffling.stations == orig_generator.stations

    def test_variables(self, bootstrap_no_shuffling, orig_generator):
        assert bootstrap_no_shuffling.variables == orig_generator.variables

    def test_window_history_size(self, bootstrap_no_shuffling, orig_generator):
        assert bootstrap_no_shuffling.window_history_size == orig_generator.window_history_size

    def test_get_generator(self, bootstrap, orig_generator):
        station = bootstrap.stations[0]
        var = bootstrap.variables[0]
        var_others = bootstrap.variables[1:]
        gen = bootstrap.get_generator(station, var)
        assert isinstance(gen, BootStrapGenerator)
        assert gen.number_of_boots == bootstrap.number_of_bootstraps
        assert gen.variables == bootstrap.variables
        expected = orig_generator.get_data_generator(station).get_transposed_history()
        assert xr.testing.assert_equal(gen.history_orig, expected) is None
        assert xr.testing.assert_equal(gen.history, expected.sel(variables=var_others)) is None
        assert gen.shuffled.variables == "o3"

    # NOTE(review): this target still uses `mlair.data_handling` while the bootstraps import uses
    # `mlair.data_handler`; the module of DataGenerator is not visible here -- TODO confirm the path.
    @mock.patch("mlair.data_handling.data_generator.DataGenerator._load_pickle_data", side_effect=FileNotFoundError)
    def test_get_generator_different_generator(self, mock_load_pickle, data_path, orig_generator):
        """Shuffled data created for a larger window is reused for a smaller window size."""
        BootStraps(orig_generator, data_path, 20)  # to create
        orig_generator.window_history_size = 4
        bootstrap = BootStraps(orig_generator, data_path, 20)
        station = bootstrap.stations[0]
        var = bootstrap.variables[0]
        var_others = bootstrap.variables[1:]
        gen = bootstrap.get_generator(station, var)
        expected = orig_generator.get_data_generator(station, load_local_tmp_storage=False).get_transposed_history()
        assert xr.testing.assert_equal(gen.history_orig, expected) is None
        assert xr.testing.assert_equal(gen.history, expected.sel(variables=var_others)) is None
        assert gen.shuffled.variables == "o3"
        assert gen.shuffled.shape[:-1] == expected.shape[:-1]
        assert gen.shuffled.shape[-1] == 20

    def test_get_labels(self, bootstrap, orig_generator):
        """Labels are repeated once per bootstrap along the first axis."""
        station = bootstrap.stations[0]
        labels = bootstrap.get_labels(station)
        labels_orig = orig_generator.get_data_generator(station).get_transposed_label()
        assert labels.shape == (labels_orig.shape[0] * bootstrap.number_of_bootstraps, *labels_orig.shape[1:])
        assert np.testing.assert_array_equal(labels[:labels_orig.shape[0], :], labels_orig.values) is None

    def test_get_orig_prediction(self, bootstrap, data_path, orig_generator):
        """Stored predictions are loaded and repeated once per bootstrap along the first axis."""
        station = bootstrap.stations[0]
        labels = orig_generator.get_data_generator(station).get_transposed_label()
        predictions = labels.expand_dims({"type": ["CNN"]}, -1)
        file_name = "test_prediction.nc"
        predictions.to_netcdf(os.path.join(data_path, file_name))
        res = bootstrap.get_orig_prediction(data_path, file_name)
        assert (*res.shape, 1) == (predictions.shape[0] * bootstrap.number_of_bootstraps, *predictions.shape[1:])
        assert np.testing.assert_array_equal(res[:predictions.shape[0], :], predictions.squeeze().values) is None

    def test_load_shuffled_data(self, bootstrap, orig_generator):
        station = bootstrap.stations[0]
        hist = orig_generator.get_data_generator(station).get_transposed_history()
        shuffled_data = bootstrap._load_shuffled_data(station, ["o3", "temp"])
        assert isinstance(shuffled_data, xr.DataArray)
        assert hist.shape[0] >= shuffled_data.shape[0]  # longer window length lead to shorter datetime axis in shuffled
        assert hist.shape[1] <= shuffled_data.shape[1]  # longer window length in shuffled
        assert hist.shape[2] == shuffled_data.shape[2]
        assert hist.shape[3] <= shuffled_data.shape[3]  # potentially more variables in shuffled
        assert bootstrap.number_of_bootstraps == shuffled_data.shape[4]
        assert shuffled_data.mean().compute()
        assert np.testing.assert_almost_equal(shuffled_data.mean().compute(), hist.mean(), decimal=1) is None
        assert shuffled_data.max() <= hist.max()
        assert shuffled_data.min() >= hist.min()

    def test_get_shuffled_data_file(self, bootstrap):
        file_name = bootstrap._get_shuffled_data_file("DEBW107", ["o3"])
        assert file_name == os.path.join(bootstrap.bootstrap_path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc")

    def test_get_shuffled_data_file_not_found(self, bootstrap_no_shuffling, data_path):
        # the attribute is `number_of_bootstraps` (see test_init_no_shuffling); `number_of_boots` was never read
        bootstrap_no_shuffling.number_of_bootstraps = 100
        os.makedirs(data_path)  # the fixture removed data_path, so recreate it empty
        with pytest.raises(FileNotFoundError) as e:
            bootstrap_no_shuffling._get_shuffled_data_file("DEBW107", ["o3"])
        assert "Could not find a file to match pattern" in e.value.args[0]

    def test_create_file_regex(self, bootstrap_no_shuffling):
        regex = bootstrap_no_shuffling._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
        assert regex.match("DEBW108_h2o_hum_latent_o3_temp_h20_hist10_nboots10_shuffled.nc")
        # a file name without the nboots part must not match (previously missing `assert`, so never checked)
        assert regex.match("DEBW108_h2o_hum_latent_o3_temp_hist10_shuffled.nc") is None

    def test_filter_files(self, bootstrap_no_shuffling):
        regex = bootstrap_no_shuffling._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
        test_list = ["DEBW108_o3_test23_test_shuffled.nc",
                     "DEBW107_o3_test23_test_shuffled.nc",
                     "DEBW108_o3_test23_test.nc",
                     "DEBW108_h2o_o3_temp_test_shuffled.nc",
                     "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc",
                     "DEBW108_o3_temp_hist9_nboots20_shuffled.nc",
                     "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"]
        f = bootstrap_no_shuffling._filter_files
        assert f(regex, test_list, 10, 10) is None
        assert f(regex, test_list, 9, 10) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
        assert f(regex, test_list, 9, 20) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"