Skip to content
Snippets Groups Projects
Select Git revision
  • 0a52888526e1a038db7089e1d218fbe0987d4310
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
41 results

test_bootstraps.py

Blame
  • user avatar
    lukas leufen authored
    0a528885
    History
    test_bootstraps.py 15.99 KiB
    import logging
    import os
    import shutil
    
    import mock
    import numpy as np
    import pytest
    import xarray as xr
    
    from mlair.data_handling.bootstraps import BootStraps, CreateShuffledData, BootStrapGenerator
    from mlair.data_handling.data_generator import DataGenerator
    
    
    @pytest.fixture
    def orig_generator(data_path):
        """Build the reference AIRBASE DataGenerator for DEBW107 and DEBW013 (o3 target, 2010 to 2014)."""
        station_ids = ['DEBW107', 'DEBW013']
        input_vars = ['o3', 'temp']
        var_statistics = {"o3": "dma8eu", "temp": "maximum"}
        return DataGenerator(data_path, 'AIRBASE', station_ids, input_vars, 'datetime', 'variables', 'o3',
                             start=2010, end=2014, statistics_per_var=var_statistics)
    
    
    @pytest.fixture
    def data_path():
        """Return the ``data`` directory located next to this test module, creating it if it does not exist."""
        module_dir = os.path.dirname(__file__)
        target_dir = os.path.join(module_dir, "data")
        if os.path.exists(target_dir):
            return target_dir
        os.makedirs(target_dir)
        return target_dir
    
    
    class TestBootStrapGenerator:
        """Unit tests for BootStrapGenerator, fed with the first station's transposed history."""

        @pytest.fixture
        def hist(self, orig_generator):
            first_station = orig_generator.get_data_generator(0)
            return first_station.get_transposed_history()

        @pytest.fixture
        def boot_gen(self, hist):
            # A single boot (index 0) whose values are history + 1, so the "shuffled" content is predictable.
            single_boot = hist.expand_dims({"boots": [0]}) + 1
            return BootStrapGenerator(20, hist, single_boot, ["o3", "temp"], "o3")

        @staticmethod
        def _check_first_element(element, hist):
            """Check that temp is passed through unchanged and o3 equals the (history + 1) boot."""
            xr.testing.assert_equal(element.sel(variables="temp"), hist.sel(variables="temp"))
            xr.testing.assert_allclose(element.sel(variables="o3"), hist.sel(variables="o3") + 1)

        def test_init(self, boot_gen, hist):
            assert boot_gen.number_of_boots == 20
            assert boot_gen.variables == ["o3", "temp"]
            xr.testing.assert_equal(boot_gen.history_orig, hist)
            # history keeps every variable except the bootstrapped one
            xr.testing.assert_equal(boot_gen.history, hist.sel(variables=["temp"]))
            xr.testing.assert_allclose(boot_gen.shuffled - 1,
                                       hist.sel(variables="o3").expand_dims({"boots": [0]}))

        def test_len(self, boot_gen):
            assert 20 == len(boot_gen)

        def test_get_shuffled(self, boot_gen, hist):
            # __get_shuffled is private, hence the name-mangled access
            shuffled = boot_gen._BootStrapGenerator__get_shuffled(0)
            dim_order = ("datetime", "window", "Stations", "variables")
            xr.testing.assert_equal(shuffled, hist.sel(variables=["o3"]).transpose(*dim_order) + 1)

        def test_getitem(self, boot_gen, hist):
            self._check_first_element(boot_gen[0], hist)

        def test_next(self, boot_gen, hist):
            boot_iter = iter(boot_gen)
            self._check_first_element(next(boot_iter), hist)
            # the fixture only provides boot 0; advancing further is expected to raise KeyError
            with pytest.raises(KeyError):
                next(boot_iter)
    
    
    class TestCreateShuffledData:
        """Tests for CreateShuffledData: creating, validating and shuffling bootstrap files."""

        @staticmethod
        def _touch(directory, file_name):
            """Create the empty file ``directory/file_name``; raise FileExistsError if it already exists.

            Used instead of ``os.mknod``, which is unavailable on Windows and requires superuser privileges on
            macOS, so the tests stay portable.
            """
            with open(os.path.join(directory, file_name), "x"):
                pass

        @pytest.fixture
        def shuffled_data(self, orig_generator, data_path):
            return CreateShuffledData(orig_generator, 20, data_path)

        @pytest.fixture
        @mock.patch("mlair.data_handling.bootstraps.CreateShuffledData.create_shuffled_data", return_value=None)
        def shuffled_data_no_creation(self, mock_create_shuffle_data, orig_generator, data_path):
            # create_shuffled_data is mocked while the object is constructed
            return CreateShuffledData(orig_generator, 20, data_path)

        @pytest.fixture
        def shuffled_data_clean(self, shuffled_data_no_creation):
            # start from an empty bootstrap directory
            shutil.rmtree(shuffled_data_no_creation.bootstrap_path)
            os.makedirs(shuffled_data_no_creation.bootstrap_path)
            assert os.listdir(shuffled_data_no_creation.bootstrap_path) == []  # just to check for a clean working directory
            return shuffled_data_no_creation

        def test_init(self, shuffled_data_no_creation, data_path):
            assert isinstance(shuffled_data_no_creation.data, DataGenerator)
            assert shuffled_data_no_creation.number_of_bootstraps == 20
            assert shuffled_data_no_creation.bootstrap_path == data_path

        def test_create_shuffled_data_create_new(self, shuffled_data_clean, data_path, caplog):
            caplog.set_level(logging.INFO)
            shuffled_data_clean.data.data_path_tmp = data_path
            assert shuffled_data_clean.create_shuffled_data() is None
            assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")
            assert caplog.record_tuples[1] == ('root', logging.INFO, "create bootstap data for DEBW107")
            assert caplog.record_tuples[3] == ('root', logging.INFO, "create bootstap data for DEBW013")
            assert "DEBW107_o3_temp_hist7_nboots20_shuffled.nc" in os.listdir(data_path)
            assert "DEBW013_o3_temp_hist7_nboots20_shuffled.nc" in os.listdir(data_path)

        def test_create_shuffled_data_some_valid(self, shuffled_data_clean, data_path, caplog):
            shuffled_data_clean.data.data_path_tmp = data_path
            shuffled_data_clean.create_shuffled_data()
            caplog.clear()
            caplog.set_level(logging.INFO)
            # Pretend DEBW013 was stored with a shorter history (5) but more boots (30). Only DEBW013 should be
            # recreated, with history 7 and the larger boot number kept.
            os.rename(os.path.join(data_path, "DEBW013_o3_temp_hist7_nboots20_shuffled.nc"),
                      os.path.join(data_path, "DEBW013_o3_temp_hist5_nboots30_shuffled.nc"))
            shuffled_data_clean.create_shuffled_data()
            assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")
            assert caplog.record_tuples[1] == ('root', logging.INFO, "create bootstap data for DEBW013")
            assert "DEBW107_o3_temp_hist7_nboots20_shuffled.nc" in os.listdir(data_path)
            assert "DEBW013_o3_temp_hist7_nboots30_shuffled.nc" in os.listdir(data_path)
            assert "DEBW013_o3_temp_hist5_nboots30_shuffled.nc" not in os.listdir(data_path)

        def test_set_file_path(self, shuffled_data_no_creation):
            res = shuffled_data_no_creation._set_file_path("DEBWtest", "o3_temp_wind", 10, 5)
            assert "DEBWtest_o3_temp_wind_hist10_nboots5_shuffled.nc" in res
            assert shuffled_data_no_creation.bootstrap_path in res

        def test_valid_bootstrap_file_blank(self, shuffled_data_clean):
            assert shuffled_data_clean.valid_bootstrap_file("DEBWtest", "o3_temp", 10) == (False, 20)

        def test_valid_bootstrap_file_already_satisfied(self, shuffled_data_clean, data_path):
            station, variables, window = "DEBWtest2", "o3_temp", 5
            self._touch(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
            assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (True, None)
            self._touch(data_path, f"{station}_{variables}_hist5_nboots100_shuffled.dat")
            assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (True, None)
            self._touch(data_path, f"{station}_{variables}_hist10_nboots50_shuffled.dat")
            self._touch(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")
            assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (True, None)

        def test_valid_bootstrap_file_reload_data_window(self, shuffled_data_clean, data_path):
            station, variables, window = "DEBWtest2", "o3_temp", 20
            self._touch(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
            self._touch(data_path, f"{station}_{variables}_hist5_nboots100_shuffled.dat")
            self._touch(data_path, f"{station}_{variables}_hist10_nboots50_shuffled.dat")
            self._touch(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")  # <- DEBWtest21
            #  need to reload data and therefore remove not fitting history size in all files for this station
            assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (False, 100)
            assert len(os.listdir(data_path)) == 1  # keep only data from other station DEBWtest21

        def test_valid_bootstrap_file_reload_data_boots(self, shuffled_data_clean, data_path):
            station, variables, window = "DEBWtest2", "o3_temp", 5
            self._touch(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
            self._touch(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")  # <- DEBWtest21
            # reload because expanded boot number
            shuffled_data_clean.number_of_bootstraps = 60
            assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (False, 60)
            assert len(os.listdir(data_path)) == 1

        def test_valid_bootstrap_file_reload_data_use_max_file_boot(self, shuffled_data_clean, data_path):
            station, variables, window = "DEBWtest2", "o3_temp", 20
            self._touch(data_path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")
            self._touch(data_path, f"{station}_{variables}_hist5_nboots60_shuffled.dat")
            self._touch(data_path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")  # <- DEBWtest21
            # reload because of expanded window size, but use maximum boot number from file names
            assert shuffled_data_clean.valid_bootstrap_file(station, variables, window) == (False, 60)

        def test_shuffle(self, shuffled_data_no_creation):
            dummy = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]])
            res = shuffled_data_no_creation.shuffle(dummy, chunks=(2, 3)).compute()
            assert res.shape == dummy.shape
            assert dummy.max() >= res.max()
            assert dummy.min() <= res.min()
            # shuffled values must be drawn from the original values only
            assert set(np.unique(res)).issubset({1, 2, 3})
    
    
    class TestBootStraps:
        """Tests for BootStraps: wiring of shuffled data, generators, labels and shuffled-file lookup."""

        @pytest.fixture
        def bootstrap(self, orig_generator, data_path):
            return BootStraps(orig_generator, data_path, 20)

        @pytest.fixture
        @mock.patch("mlair.data_handling.bootstraps.CreateShuffledData", return_value=None)
        def bootstrap_no_shuffling(self, mock_create_shuffle_data, orig_generator, data_path):
            # CreateShuffledData is mocked and the data directory is removed before construction
            shutil.rmtree(data_path)
            return BootStraps(orig_generator, data_path, 20)

        def test_init_no_shuffling(self, bootstrap_no_shuffling, data_path):
            assert isinstance(bootstrap_no_shuffling, BootStraps)
            assert bootstrap_no_shuffling.number_of_bootstraps == 20
            assert bootstrap_no_shuffling.bootstrap_path == data_path

        def test_init_with_shuffling(self, orig_generator, data_path, caplog):
            caplog.set_level(logging.INFO)
            BootStraps(orig_generator, data_path, 20)
            assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")

        def test_stations(self, bootstrap_no_shuffling, orig_generator):
            assert bootstrap_no_shuffling.stations == orig_generator.stations

        def test_variables(self, bootstrap_no_shuffling, orig_generator):
            assert bootstrap_no_shuffling.variables == orig_generator.variables

        def test_window_history_size(self, bootstrap_no_shuffling, orig_generator):
            assert bootstrap_no_shuffling.window_history_size == orig_generator.window_history_size

        def test_get_generator(self, bootstrap, orig_generator):
            station = bootstrap.stations[0]
            var = bootstrap.variables[0]
            var_others = bootstrap.variables[1:]
            gen = bootstrap.get_generator(station, var)
            assert isinstance(gen, BootStrapGenerator)
            assert gen.number_of_boots == bootstrap.number_of_bootstraps
            assert gen.variables == bootstrap.variables
            expected = orig_generator.get_data_generator(station).get_transposed_history()
            assert xr.testing.assert_equal(gen.history_orig, expected) is None
            assert xr.testing.assert_equal(gen.history, expected.sel(variables=var_others)) is None
            assert gen.shuffled.variables == "o3"

        @mock.patch("mlair.data_handling.data_generator.DataGenerator._load_pickle_data", side_effect=FileNotFoundError)
        def test_get_generator_different_generator(self, mock_load_pickle, data_path, orig_generator):
            BootStraps(orig_generator, data_path, 20)  # to create
            # a smaller history than the one used to create the shuffled data above
            orig_generator.window_history_size = 4
            bootstrap = BootStraps(orig_generator, data_path, 20)
            station = bootstrap.stations[0]
            var = bootstrap.variables[0]
            var_others = bootstrap.variables[1:]
            gen = bootstrap.get_generator(station, var)
            expected = orig_generator.get_data_generator(station, load_local_tmp_storage=False).get_transposed_history()
            assert xr.testing.assert_equal(gen.history_orig, expected) is None
            assert xr.testing.assert_equal(gen.history, expected.sel(variables=var_others)) is None
            assert gen.shuffled.variables == "o3"
            assert gen.shuffled.shape[:-1] == expected.shape[:-1]
            assert gen.shuffled.shape[-1] == 20

        def test_get_labels(self, bootstrap, orig_generator):
            station = bootstrap.stations[0]
            labels = bootstrap.get_labels(station)
            labels_orig = orig_generator.get_data_generator(station).get_transposed_label()
            # labels are repeated once per bootstrap along the first axis
            assert labels.shape == (labels_orig.shape[0] * bootstrap.number_of_bootstraps, *labels_orig.shape[1:])
            assert np.testing.assert_array_equal(labels[:labels_orig.shape[0], :], labels_orig.values) is None

        def test_get_orig_prediction(self, bootstrap, data_path, orig_generator):
            station = bootstrap.stations[0]
            labels = orig_generator.get_data_generator(station).get_transposed_label()
            predictions = labels.expand_dims({"type": ["CNN"]}, -1)
            file_name = "test_prediction.nc"
            predictions.to_netcdf(os.path.join(data_path, file_name))
            res = bootstrap.get_orig_prediction(data_path, file_name)
            # the trailing "type" axis of length 1 is dropped in the result, hence the appended 1
            assert (*res.shape, 1) == (predictions.shape[0] * bootstrap.number_of_bootstraps, *predictions.shape[1:])
            assert np.testing.assert_array_equal(res[:predictions.shape[0], :], predictions.squeeze().values) is None

        def test_load_shuffled_data(self, bootstrap, orig_generator):
            station = bootstrap.stations[0]
            hist = orig_generator.get_data_generator(station).get_transposed_history()
            shuffled_data = bootstrap._load_shuffled_data(station, ["o3", "temp"])
            assert isinstance(shuffled_data, xr.DataArray)
            assert hist.shape[0] >= shuffled_data.shape[0]  # longer window length lead to shorter datetime axis in shuffled
            assert hist.shape[1] <= shuffled_data.shape[1]  # longer window length in shuffled
            assert hist.shape[2] == shuffled_data.shape[2]
            assert hist.shape[3] <= shuffled_data.shape[3]  # potentially more variables in shuffled
            assert bootstrap.number_of_bootstraps == shuffled_data.shape[4]
            # NOTE(review): only checks that the mean is non-zero (truthy)
            assert shuffled_data.mean().compute()
            assert np.testing.assert_almost_equal(shuffled_data.mean().compute(), hist.mean(), decimal=1) is None
            assert shuffled_data.max() <= hist.max()
            assert shuffled_data.min() >= hist.min()

        def test_get_shuffled_data_file(self, bootstrap):
            file_name = bootstrap._get_shuffled_data_file("DEBW107", ["o3"])
            assert file_name == os.path.join(bootstrap.bootstrap_path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc")

        def test_get_shuffled_data_file_not_found(self, bootstrap_no_shuffling, data_path):
            # Use the real attribute name (number_of_bootstraps, not number_of_boots) so the change takes effect.
            bootstrap_no_shuffling.number_of_bootstraps = 100
            # The fixture removed data_path; recreate it empty so that no file can match.
            os.makedirs(data_path)
            with pytest.raises(FileNotFoundError) as e:
                bootstrap_no_shuffling._get_shuffled_data_file("DEBW107", ["o3"])
            assert "Could not find a file to match pattern" in e.value.args[0]

        def test_create_file_regex(self, bootstrap_no_shuffling):
            regex = bootstrap_no_shuffling._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
            assert regex.match("DEBW108_h2o_hum_latent_o3_temp_h20_hist10_nboots10_shuffled.nc")
            # a file name without the nboots part must not match
            assert regex.match("DEBW108_h2o_hum_latent_o3_temp_hist10_shuffled.nc") is None

        def test_filter_files(self, bootstrap_no_shuffling):
            regex = bootstrap_no_shuffling._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
            test_list = ["DEBW108_o3_test23_test_shuffled.nc",
                         "DEBW107_o3_test23_test_shuffled.nc",
                         "DEBW108_o3_test23_test.nc",
                         "DEBW108_h2o_o3_temp_test_shuffled.nc",
                         "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc",
                         "DEBW108_o3_temp_hist9_nboots20_shuffled.nc",
                         "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"]
            f = bootstrap_no_shuffling._filter_files
            assert f(regex, test_list, 10, 10) is None
            assert f(regex, test_list, 9, 10) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
            assert f(regex, test_list, 9, 20) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"