Patch summary: BootStraps.get_labels / get_orig_prediction use np.tile instead of repeat;
the legacy test class is renamed to TestBootStrapsOld and a new TestBootStraps suite is added.
Note: the line structure was restored from a whitespace-collapsed copy. Blank context lines and
indentation are reconstructed, and the "index" blob ids predate the latest content edits.

diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
index 3e86ed96b70f40ac77c9d3f7df1b774a2f56060d..3fdabe4c9ef6faafa447967b86cce1aaab6d2bcf 100644
--- a/src/data_handling/bootstraps.py
+++ b/src/data_handling/bootstraps.py
@@ -165,12 +165,12 @@ class BootStraps:
 
     def get_labels(self, key: Union[str, int]):
         """
-        Reepats labels for given key by the number of boots and yield it one by one.
+        Repeats labels for given key by the number of boots and returns them as a single array.
         :param key: key of station (either station name as string or the position in generator as integer)
-        :return: yields labels for length of boots
+        :return: labels tiled number_of_bootstraps times along the first axis
         """
         labels = self.data[key][1]
-        return labels.data.repeat(self.number_of_bootstraps, axis=0)
+        return np.tile(labels.data, (self.number_of_bootstraps,) + (1,) * (labels.data.ndim - 1))
 
     def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN"):
         """
@@ -182,7 +182,7 @@ class BootStraps:
         """
         file = os.path.join(path, file_name)
         prediction = xr.open_dataarray(file).sel(type=prediction_name).squeeze()
-        vals = prediction.data.repeat(self.number_of_bootstraps, axis=0)
+        vals = np.tile(prediction.data, (self.number_of_bootstraps, 1))
         return vals[~np.isnan(vals).any(axis=1), :]
 
     def _load_shuffled_data(self, station: str, variables: List[str]) -> xr.DataArray:
diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py
index 46f563bb2eebab35b879e5e07327ceee9ed5c6cb..45d8c86fb24547e870b190885abeda82aacb882c 100644
--- a/test/test_data_handling/test_bootstraps.py
+++ b/test/test_data_handling/test_bootstraps.py
@@ -17,7 +17,7 @@ import xarray as xr
 
 ### old ###
 
-class TestBootStraps:
+class TestBootStrapsOld:
 
     @pytest.fixture
     def boot_gen(self, orig_generator):
@@ -29,55 +29,6 @@ class TestBootStraps:
         dummy_content.to_netcdf(os.path.join(path, "forecasts_norm_DEBW107_test.nc"))
         return BootStraps(orig_generator, path, 20)
 
-    @pytest.fixture
-    def boot_gen_real(self, orig_generator):
-        path = os.path.join(os.path.dirname(__file__), 'data')
-        for hist, _ in orig_generator:
-            hist = hist.expand_dims({"boots": [1]})
-            station = orig_generator.get_station_key(orig_generator._iterator-1)
-            hist.to_netcdf(os.path.join(path, f"{station}_o3_temp_hist7_nboots20_shuffled.nc"))
-        return BootStrapGenerator(orig_generator, 20, path)
-
-    def test_init(self, orig_generator):
-        gen = BootStraps(orig_generator, os.path.join(os.path.dirname(__file__), 'data'), 20)
-        # assert gen.stations == ["DEBW107", "DEBW013"]
-        # assert gen.variables == ["o3", "temp"]
-        assert isinstance(gen.data, DataGenerator)
-        assert gen.number_of_bootstraps == 20
-        assert gen.bootstrap_path == os.path.join(os.path.dirname(__file__), 'data')
-
-    def test_get_generator(self, boot_gen_real, orig_generator):
-        res = boot_gen_real.get_generator("DEBW107", "o3")
-        hist = orig_generator.get_data_generator("DEBW107").get_transposed_history()
-        assert xr.testing.assert_equal(res[0], hist) is None
-        label = orig_generator.get_data_generator("DEBW107").get_transposed_label()
-        assert xr.testing.assert_equal(res[1], label) is None
-        assert isinstance(res[2], typing.Callable)
-        assert res[3] == 20
-
-    def test_get_generator_station_var_wise(self, boot_gen, orig_generator):
-        res = boot_gen.get_generator("DEBW107", "o3")
-        hist = orig_generator.get_data_generator("DEBW107").get_transposed_history()
-        assert xr.testing.assert_equal(res[0], hist) is None
-        label = orig_generator.get_data_generator("DEBW107").get_transposed_label()
-        assert xr.testing.assert_equal(res[1], label) is None
-        assert isinstance(res[2], typing.Callable)
-        assert res[3] == 20
-
-    def test_get_bootstrap_station_var_wise_meta(self, boot_gen):
-        meta = boot_gen.get_bootstrap_meta_station_var_wise("DEBW107", "o3")
-        labels = boot_gen.orig_generator.get_data_generator("DEBW107").get_transposed_label().shape[0]
-        assert np.shape(meta) == (labels * boot_gen.number_of_boots, 2)
-        assert np.testing.assert_array_equal(np.unique(meta), ["DEBW107", "o3"]) is None
-
-    def test_get_labels(self, boot_gen):
-        res = []
-        for label in boot_gen.get_labels("DEBW107"):
-            res.append(label)
-        assert len(res) == boot_gen.number_of_boots
-        assert xr.testing.assert_equal(res[0], res[-1]) is None
-        assert PyTestAllEqual(res).is_true()
-
     def test_get_orig_prediction(self, boot_gen):
         path = boot_gen.orig_generator.data_path
         res = []
@@ -91,29 +42,6 @@ class TestBootStraps:
         assert isinstance(shuffled_data, xr.DataArray)
         assert all(shuffled_data.compute().values == [1, 2, 3])
 
-    def test_get_shuffled_data_file(self, boot_gen):
-        file_name = boot_gen._get_shuffled_data_file("DEBW107", ["o3"])
-        assert file_name == os.path.join(boot_gen.bootstrap_path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc")
-
-    def test_get_shuffled_data_file_not_found(self, boot_gen):
-        boot_gen.number_of_boots = 100
-        with pytest.raises(FileNotFoundError) as e:
-            boot_gen._get_shuffled_data_file("DEBW107", ["o3"])
-        assert "Could not find a file to match pattern" in e.value.args[0]
-
-    def test_create_file_regex(self, boot_gen):
-        regex = boot_gen._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
-        test_list = ["DEBW108_o3_test23_test_shuffled.nc",
-                     "DEBW107_o3_test23_test_shuffled.nc",
-                     "DEBW108_o3_test23_test.nc",
-                     "DEBW108_h2o_o3_temp_test_shuffled.nc",
-                     "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc",
-                     "DEBW108_o3_temp_hist9_nboots20_shuffled.nc",
-                     "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"]
-        assert boot_gen._filter_files(regex, test_list, 10, 10) is None
-        assert boot_gen._filter_files(regex, test_list, 9, 10) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
-        assert boot_gen._filter_files(regex, test_list, 9, 20) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
-
 
 ### new ###
 
@@ -271,3 +199,89 @@ class TestCreateShuffledData:
         assert dummy.max() >= res.max()
         assert dummy.min() <= res.min()
         assert set(np.unique(res)).issubset({1, 2, 3})
+
+
+class TestBootStraps:
+
+    @pytest.fixture
+    def bootstrap(self, orig_generator, data_path):
+        return BootStraps(orig_generator, data_path, 20)
+
+    @pytest.fixture
+    @mock.patch("src.data_handling.bootstraps.CreateShuffledData", return_value=None)
+    def bootstrap_no_shuffling(self, mock_create_shuffle_data, orig_generator, data_path):
+        shutil.rmtree(data_path)
+        return BootStraps(orig_generator, data_path, 20)
+
+    def test_init_no_shuffling(self, bootstrap_no_shuffling, data_path):
+        assert isinstance(bootstrap_no_shuffling, BootStraps)
+        assert bootstrap_no_shuffling.number_of_bootstraps == 20
+        assert bootstrap_no_shuffling.bootstrap_path == data_path
+
+    def test_init_with_shuffling(self, orig_generator, data_path, caplog):
+        caplog.set_level(logging.INFO)
+        BootStraps(orig_generator, data_path, 20)
+        assert caplog.record_tuples[0] == ('root', logging.INFO, "create / check shuffled bootstrap data")
+
+    def test_stations(self, bootstrap_no_shuffling, orig_generator):
+        assert bootstrap_no_shuffling.stations == orig_generator.stations
+
+    def test_variables(self, bootstrap_no_shuffling, orig_generator):
+        assert bootstrap_no_shuffling.variables == orig_generator.variables
+
+    def test_window_history_size(self, bootstrap_no_shuffling, orig_generator):
+        assert bootstrap_no_shuffling.window_history_size == orig_generator.window_history_size
+
+    def test_get_generator(self, bootstrap, orig_generator):
+        station = bootstrap.stations[0]
+        var = bootstrap.variables[0]
+        var_others = bootstrap.variables[1:]
+        gen = bootstrap.get_generator(station, var)
+        assert isinstance(gen, BootStrapGenerator)
+        assert gen.number_of_boots == bootstrap.number_of_bootstraps
+        assert gen.variables == bootstrap.variables
+        expected = orig_generator.get_data_generator(station).get_transposed_history()
+        assert xr.testing.assert_equal(gen.history_orig, expected) is None
+        assert xr.testing.assert_equal(gen.history, expected.sel(variables=var_others)) is None
+        assert gen.shuffled.variables == "o3"
+
+    def test_get_labels(self, bootstrap, orig_generator):
+        station = bootstrap.stations[0]
+        labels = bootstrap.get_labels(station)
+        labels_orig = orig_generator.get_data_generator(station).get_transposed_label()
+        assert labels.shape == (labels_orig.shape[0] * bootstrap.number_of_bootstraps, *labels_orig.shape[1:])
+        assert np.testing.assert_array_equal(labels[:labels_orig.shape[0], :], labels_orig.values) is None
+
+    def test_get_orig_prediction(self):
+        pytest.skip("test not implemented yet")
+
+    def test_load_shuffled_data(self):
+        pytest.skip("test not implemented yet")
+
+    def test_get_shuffled_data_file(self, bootstrap):
+        file_name = bootstrap._get_shuffled_data_file("DEBW107", ["o3"])
+        assert file_name == os.path.join(bootstrap.bootstrap_path, "DEBW107_o3_temp_hist7_nboots20_shuffled.nc")
+
+    def test_get_shuffled_data_file_not_found(self, bootstrap_no_shuffling, data_path):
+        bootstrap_no_shuffling.number_of_bootstraps = 100
+        os.makedirs(data_path)
+        with pytest.raises(FileNotFoundError) as e:
+            bootstrap_no_shuffling._get_shuffled_data_file("DEBW107", ["o3"])
+        assert "Could not find a file to match pattern" in e.value.args[0]
+
+    def test_create_file_regex(self, bootstrap_no_shuffling):
+        regex = bootstrap_no_shuffling._create_file_regex("DEBW108", ["o3", "temp", "h2o"])
+        test_list = ["DEBW108_o3_test23_test_shuffled.nc",
+                     "DEBW107_o3_test23_test_shuffled.nc",
+                     "DEBW108_o3_test23_test.nc",
+                     "DEBW108_h2o_o3_temp_test_shuffled.nc",
+                     "DEBW108_h2o_hum_latent_o3_temp_u_v_test23_test_shuffled.nc",
+                     "DEBW108_o3_temp_hist9_nboots20_shuffled.nc",
+                     "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"]
+        f = bootstrap_no_shuffling._filter_files
+        assert f(regex, test_list, 10, 10) is None
+        assert f(regex, test_list, 9, 10) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
+        assert f(regex, test_list, 9, 20) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc"
+
+    def test_filter_files(self):
+        pytest.skip("test not implemented yet")