Select Git revision
test_bootstraps.py
test_bootstraps.py 15.99 KiB
import logging
import os
import shutil
import mock
import numpy as np
import pytest
import xarray as xr
from mlair.data_handling.bootstraps import BootStraps, CreateShuffledData, BootStrapGenerator
from mlair.data_handling.data_generator import DataGenerator
@pytest.fixture
def orig_generator(data_path):
    """Build the reference DataGenerator shared by the bootstrap tests.

    Loads the two stations DEBW107 and DEBW013 from network 'AIRBASE' with the variables o3 and temp
    for start=2010 / end=2014, using per-variable statistics dma8eu (o3) and maximum (temp).
    The positional 'o3' after the dimension names is presumably the target variable - verify
    against DataGenerator's signature.
    """
    station_ids = ['DEBW107', 'DEBW013']
    input_variables = ['o3', 'temp']
    per_var_statistics = {"o3": "dma8eu", "temp": "maximum"}
    return DataGenerator(data_path, 'AIRBASE', station_ids, input_variables, 'datetime', 'variables', 'o3',
                         start=2010, end=2014, statistics_per_var=per_var_statistics)
@pytest.fixture
def data_path():
    """Return the ``data`` directory located next to this test module, creating it if missing.

    ``exist_ok=True`` replaces a separate existence check, which avoids the race between checking
    and creating the directory (e.g. when tests run in parallel).
    """
    path = os.path.join(os.path.dirname(__file__), "data")
    os.makedirs(path, exist_ok=True)
    return path
class TestBootStrapGenerator:
    """Tests for BootStrapGenerator, built on the history of the first station of ``orig_generator``."""

    @pytest.fixture
    def hist(self, orig_generator):
        """Transposed history of the station at index 0 of the reference generator."""
        first_station_gen = orig_generator.get_data_generator(0)
        return first_station_gen.get_transposed_history()

    @pytest.fixture
    def boot_gen(self, hist):
        """Generator for 20 boots whose "shuffled" data is simply ``hist + 1`` with a single boot (0)."""
        fake_shuffled = hist.expand_dims({"boots": [0]}) + 1
        return BootStrapGenerator(20, hist, fake_shuffled, ["o3", "temp"], "o3")

    @staticmethod
    def _assert_first_boot(element, hist):
        """Check that temp is passed through unchanged and o3 equals the shuffled (history + 1) data."""
        assert xr.testing.assert_equal(element.sel(variables="temp"), hist.sel(variables="temp")) is None
        assert xr.testing.assert_allclose(element.sel(variables="o3"), hist.sel(variables="o3") + 1) is None

    def test_init(self, boot_gen, hist):
        assert boot_gen.number_of_boots == 20
        assert boot_gen.variables == ["o3", "temp"]
        assert xr.testing.assert_equal(boot_gen.history_orig, hist) is None
        # the bootstrapped variable "o3" is no longer part of the remaining history
        remaining = hist.sel(variables=["temp"])
        assert xr.testing.assert_equal(boot_gen.history, remaining) is None
        shifted_back = boot_gen.shuffled - 1
        expected_shuffled = hist.sel(variables="o3").expand_dims({"boots": [0]})
        assert xr.testing.assert_allclose(shifted_back, expected_shuffled) is None

    def test_len(self, boot_gen):
        assert len(boot_gen) == 20

    def test_get_shuffled(self, boot_gen, hist):
        # access the name-mangled private method BootStrapGenerator.__get_shuffled
        boot_zero = boot_gen._BootStrapGenerator__get_shuffled(0)
        reordered = hist.sel(variables=["o3"]).transpose("datetime", "window", "Stations", "variables")
        assert xr.testing.assert_equal(boot_zero, reordered + 1) is None

    def test_getitem(self, boot_gen, hist):
        self._assert_first_boot(boot_gen[0], hist)

    def test_next(self, boot_gen, hist):
        boots = iter(boot_gen)
        self._assert_first_boot(next(boots), hist)
        # the fixture only provides boot 0, so asking for a second boot presumably fails on selection
        with pytest.raises(KeyError):
            next(boots)
class TestCreateShuffledData:
    """Tests for CreateShuffledData, which creates / validates shuffled bootstrap files per station."""

    @staticmethod
    def _touch(directory, file_name):
        """Create an empty file ``directory/file_name``.

        Portable replacement for ``os.mknod``: ``os.mknod`` is Unix-only and may fail for regular
        files on some platforms, while the tests only need an empty placeholder file.
        """
        with open(os.path.join(directory, file_name), "w"):
            pass

    @pytest.fixture
    def shuffled_data(self, orig_generator, data_path):
        return CreateShuffledData(orig_generator, 20, data_path)

    @pytest.fixture
    @mock.patch("mlair.data_handling.bootstraps.CreateShuffledData.create_shuffled_data", return_value=None)
    def shuffled_data_no_creation(self, mock_create_shuffle_data, orig_generator, data_path):
        # create_shuffled_data is patched to a no-op while constructing (presumably the constructor
        # calls it), so building this object does not create any shuffled files
        return CreateShuffledData(orig_generator, 20, data_path)

    @pytest.fixture
    def shuffled_data_clean(self, shuffled_data_no_creation):
        """Same as ``shuffled_data_no_creation`` but with an emptied bootstrap directory."""
        shutil.rmtree(shuffled_data_no_creation.bootstrap_path)
        os.makedirs(shuffled_data_no_creation.bootstrap_path)
        assert os.listdir(shuffled_data_no_creation.bootstrap_path) == []  # just to check for a clean working directory
        return shuffled_data_no_creation

    def test_init(self, shuffled_data_no_creation, data_path):
        assert isinstance(shuffled_data_no_creation.data, DataGenerator)
        assert shuffled_data_no_creation.number_of_bootstraps == 20
        assert shuffled_data_no_creation.bootstrap_path == data_path

    def test_create_shuffled_data_create_new(self, shuffled_data_clean, data_path, caplog):
        caplog.set_level(logging.INFO)
        shuffled_data_clean.data.data_path_tmp = data_path
        assert shuffled_data_clean.create_shuffled_data() is None
        assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")
        assert caplog.record_tuples[1] == ('root', logging.INFO, "create bootstap data for DEBW107")
        assert caplog.record_tuples[3] == ('root', logging.INFO, "create bootstap data for DEBW013")
        created_files = os.listdir(data_path)
        assert "DEBW107_o3_temp_hist7_nboots20_shuffled.nc" in created_files
        assert "DEBW013_o3_temp_hist7_nboots20_shuffled.nc" in created_files

    def test_create_shuffled_data_some_valid(self, shuffled_data_clean, data_path, caplog):
        shuffled_data_clean.data.data_path_tmp = data_path
        shuffled_data_clean.create_shuffled_data()
        # discard everything logged by the first run; only the second run is inspected below
        caplog.clear()
        caplog.set_level(logging.INFO)
        # make the DEBW013 file look like it was created with a smaller window (5) but more boots (30)
        os.rename(os.path.join(data_path, "DEBW013_o3_temp_hist7_nboots20_shuffled.nc"),
                  os.path.join(data_path, "DEBW013_o3_temp_hist5_nboots30_shuffled.nc"))
        shuffled_data_clean.create_shuffled_data()
        assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")
        assert caplog.record_tuples[1] == ('root', logging.INFO, "create bootstap data for DEBW013")
        # DEBW013 is regenerated with window 7 and the larger boot number 30 from the old file name
        files = os.listdir(data_path)
        assert "DEBW107_o3_temp_hist7_nboots20_shuffled.nc" in files
        assert "DEBW013_o3_temp_hist7_nboots30_shuffled.nc" in files
        assert "DEBW013_o3_temp_hist5_nboots30_shuffled.nc" not in files

    def test_set_file_path(self, shuffled_data_no_creation):
        res = shuffled_data_no_creation._set_file_path("DEBWtest", "o3_temp_wind", 10, 5)
        assert "DEBWtest_o3_temp_wind_hist10_nboots5_shuffled.nc" in res
        assert shuffled_data_no_creation.bootstrap_path in res

    def test_valid_bootstrap_file_blank(self, shuffled_data_clean):
        assert shuffled_data_clean.valid_bootstrap_file("DEBWtest", "o3_temp", 10) == (False, 20)

    def test_valid_bootstrap_file_already_satisfied(self, shuffled_data_clean, data_path):
        station, variables, window = "DEBWtest2", "o3_temp", 5
        self._touch(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (True, None)
        self._touch(data_path, f"{station}_{variables}_hist5_nboots100_shuffled.dat")
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (True, None)
        self._touch(data_path, f"{station}_{variables}_hist10_nboots50_shuffled.dat")
        self._touch(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (True, None)

    def test_valid_bootstrap_file_reload_data_window(self, shuffled_data_clean, data_path):
        station, variables, window = "DEBWtest2", "o3_temp", 20
        self._touch(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
        self._touch(data_path, f"{station}_{variables}_hist5_nboots100_shuffled.dat")
        self._touch(data_path, f"{station}_{variables}_hist10_nboots50_shuffled.dat")
        self._touch(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")  # <- DEBWtest21
        # need to reload data and therefore remove not fitting history size in all files for this station
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (False, 100)
        assert len(os.listdir(data_path)) == 1  # keep only data from other station DEBWtest21

    def test_valid_bootstrap_file_reload_data_boots(self, shuffled_data_clean, data_path):
        station, variables, window = "DEBWtest2", "o3_temp", 5
        self._touch(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
        self._touch(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")  # <- DEBWtest21
        # reload because expanded boot number
        shuffled_data_clean.number_of_bootstraps = 60
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (False, 60)
        assert len(os.listdir(data_path)) == 1

    def test_valid_bootstrap_file_reload_data_use_max_file_boot(self, shuffled_data_clean, data_path):
        station, variables, window = "DEBWtest2", "o3_temp", 20
        self._touch(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
        self._touch(data_path, f"{station}_{variables}_hist5_nboots60_shuffled.dat")
        self._touch(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")  # <- DEBWtest21
        # reload because of expanded window size, but use maximum boot number from file names
        assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (False, 60)

    def test_shuffle(self, shuffled_data_no_creation):
        dummy = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
        res = shuffled_data_no_creation.shuffle(dummy, chunks=(2, 3)).compute()
        assert res.shape == dummy.shape
        assert dummy.max() >= res.max()
        assert dummy.min() <= res.min()
        assert set(np.unique(res)).issubset({1, 2, 3})
class TestBootStraps:
    """Tests for BootStraps, which serves original data, labels and shuffled data per station."""

    @pytest.fixture
    def bootstrap(self, orig_generator, data_path):
        return BootStraps(orig_generator, data_path, 20)

    @pytest.fixture
    @mock.patch("mlair.data_handling.bootstraps.CreateShuffledData", return_value=None)
    def bootstrap_no_shuffling(self, mock_create_shuffle_data, orig_generator, data_path):
        # NOTE: deletes the whole data directory. CreateShuffledData is patched out, so no shuffled
        # files are recreated; test_get_shuffled_data_file_not_found relies on data_path being absent
        shutil.rmtree(data_path)
        return BootStraps(orig_generator, data_path, 20)

    def test_init_no_shuffling(self, bootstrap_no_shuffling, data_path):
        assert isinstance(bootstrap_no_shuffling, BootStraps)
        assert bootstrap_no_shuffling.number_of_bootstraps == 20
        assert bootstrap_no_shuffling.bootstrap_path == data_path

    def test_init_with_shuffling(self, orig_generator, data_path, caplog):
        caplog.set_level(logging.INFO)
        BootStraps(orig_generator, data_path, 20)
        assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")

    def test_stations(self, bootstrap_no_shuffling, orig_generator):
        assert bootstrap_no_shuffling.stations == orig_generator.stations

    def test_variables(self, bootstrap_no_shuffling, orig_generator):
        assert bootstrap_no_shuffling.variables == orig_generator.variables

    def test_window_history_size(self, bootstrap_no_shuffling, orig_generator):
        assert bootstrap_no_shuffling.window_history_size == orig_generator.window_history_size

    def test_get_generator(self, bootstrap, orig_generator):
        station = bootstrap.stations[0]
        var = bootstrap.variables[0]
        var_others = bootstrap.variables[1:]
        gen = bootstrap.get_generator(station, var)
        assert isinstance(gen, BootStrapGenerator)
        assert gen.number_of_boots == bootstrap.number_of_bootstraps
        assert gen.variables == bootstrap.variables
        expected = orig_generator.get_data_generator(station).get_transposed_history()
        assert xr.testing.assert_equal(gen.history_orig, expected) is None
        assert xr.testing.assert_equal(gen.history, expected.sel(variables=var_others)) is None
        assert gen.shuffled.variables == "o3"

    @mock.patch("mlair.data_handling.data_generator.DataGenerator._load_pickle_data", side_effect=FileNotFoundError)
    def test_get_generator_different_generator(self, mock_load_pickle, data_path, orig_generator):
        BootStraps(orig_generator, data_path, 20)  # to create
        # a smaller history window than the one used to create the shuffled files
        orig_generator.window_history_size = 4
        bootstrap = BootStraps(orig_generator, data_path, 20)
        station = bootstrap.stations[0]
        var = bootstrap.variables[0]
        var_others = bootstrap.variables[1:]
        gen = bootstrap.get_generator(station, var)
        expected = orig_generator.get_data_generator(station, load_local_tmp_storage=False).get_transposed_history()
        assert xr.testing.assert_equal(gen.history_orig, expected) is None
        assert xr.testing.assert_equal(gen.history, expected.sel(variables=var_others)) is None
        assert gen.shuffled.variables == "o3"
        assert gen.shuffled.shape[:-1] == expected.shape[:-1]
        assert gen.shuffled.shape[-1] == 20

    def test_get_labels(self, bootstrap, orig_generator):
        station = bootstrap.stations[0]
        labels = bootstrap.get_labels(station)
        labels_orig = orig_generator.get_data_generator(station).get_transposed_label()
        # labels are repeated once per bootstrap along the first axis
        assert labels.shape == (labels_orig.shape[0] * bootstrap.number_of_bootstraps, *labels_orig.shape[1:])
        assert np.testing.assert_array_equal(labels[:labels_orig.shape[0], :], labels_orig.values) is None

    def test_get_orig_prediction(self, bootstrap, data_path, orig_generator):
        station = bootstrap.stations[0]
        labels = orig_generator.get_data_generator(station).get_transposed_label()
        predictions = labels.expand_dims({"type": ["CNN"]}, -1)
        file_name = "test_prediction.nc"
        predictions.to_netcdf(os.path.join(data_path, file_name))
        res = bootstrap.get_orig_prediction(data_path, file_name)
        # the returned array lacks the trailing "type" axis of length 1
        assert (*res.shape, 1) == (predictions.shape[0] * bootstrap.number_of_bootstraps, *predictions.shape[1:])
        assert np.testing.assert_array_equal(res[:predictions.shape[0], :], predictions.squeeze().values) is None

    def test_load_shuffled_data(self, bootstrap, orig_generator):
        station = bootstrap.stations[0]
        hist = orig_generator.get_data_generator(station).get_transposed_history()
        shuffled_data = bootstrap._load_shuffled_data(station, ["o3", "temp"])
        assert isinstance(shuffled_data, xr.DataArray)
        assert hist.shape[0] >= shuffled_data.shape[0]  # longer window length lead to shorter datetime axis in shuffled
        assert hist.shape[1] <= shuffled_data.shape[1]  # longer window length in shuffled
        assert hist.shape[2] == shuffled_data.shape[2]
        assert hist.shape[3] <= shuffled_data.shape[3]  # potentially more variables in shuffled
        assert bootstrap.number_of_bootstraps == shuffled_data.shape[4]
        # mean must be computable and non-zero
        assert shuffled_data.mean().compute()
        assert np.testing.assert_almost_equal(shuffled_data.mean().compute(), hist.mean(), decimal=1) is None
        assert shuffled_data.max() <= hist.max()
        assert shuffled_data.min() >= hist.min()

    def test_get_shuffled_data_file(self, bootstrap):
        # a file holding more variables than requested ("o3_temp" for ["o3"]) is accepted
        file_name = bootstrap._get_shuffled_data_file("DEBW107", ["o3"])
        assert file_name == os.path.join(bootstrap.bootstrap_path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc")

    def test_get_shuffled_data_file_not_found(self, bootstrap_no_shuffling, data_path):
        # request more boots than any existing file provides; the attribute is number_of_bootstraps
        # (see test_init_no_shuffling) - setting "number_of_boots" had no effect
        bootstrap_no_shuffling.number_of_bootstraps = 100
        os.makedirs(data_path)
        with pytest.raises(FileNotFoundError) as e:
            bootstrap_no_shuffling._get_shuffled_data_file("DEBW107", ["o3"])
        assert "Could not find a file to match pattern" in e.value.args[0]

    def test_create_file_regex(self, bootstrap_no_shuffling):
        regex = bootstrap_no_shuffling._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
        assert regex.match("DEBW108_h2o_hum_latent_o3_temp_h20_hist10_nboots10_shuffled.nc")
        # a file name without the "_nboots<N>" part must not match
        assert regex.match("DEBW108_h2o_hum_latent_o3_temp_hist10_shuffled.nc") is None

    def test_filter_files(self, bootstrap_no_shuffling):
        regex = bootstrap_no_shuffling._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
        test_list = ["DEBW108_o3_test23_test_shuffled.nc",
                     "DEBW107_o3_test23_test_shuffled.nc",
                     "DEBW108_o3_test23_test.nc",
                     "DEBW108_h2o_o3_temp_test_shuffled.nc",
                     "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc",
                     "DEBW108_o3_temp_hist9_nboots20_shuffled.nc",
                     "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"]
        # presumably f(regex, files, window_history_size, number_of_bootstraps) - verify against BootStraps
        f = bootstrap_no_shuffling._filter_files
        assert f(regex, test_list, 10, 10) is None
        assert f(regex, test_list, 9, 10) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
        assert f(regex, test_list, 9, 20) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"