diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 21fc23d8a072867dbd28b2ede7255325db206164..2700cb6c0aaff7227bfea78116a5417aeae9b201 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -25,7 +25,7 @@ class BootStraps(RunEnvironment): variables_str = '_'.join(sorted(self.test_data.variables)) window = self.test_data.window_history_size for station in self.test_data.stations: - valid, _, max_nboot = self.valid_bootstrap_file(station, variables_str, window) + valid, max_nboot = self.valid_bootstrap_file(station, variables_str, window) if not valid: logging.info(f'create bootstap data for {station}') hist, _ = self.test_data[station] @@ -44,23 +44,17 @@ class BootStraps(RunEnvironment): str_re = re.compile(f"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled*") dir_list = os.listdir(self.bootstrap_path) max_nboot = self.number_bootstraps - max_window = self.number_bootstraps for file in dir_list: match = str_re.match(file) if match: window_existing = int(match.group(1)) nboot_existing = int(match.group(2)) - max_window = max([max_window, window_existing]) max_nboot = max([max_nboot, nboot_existing]) if (window_existing >= window) and (nboot_existing >= self.number_bootstraps): - return True, 0, 0 + return True, None else: os.remove(os.path.join(self.bootstrap_path, file)) - return False, max_window, max_nboot - - - - + return False, max_nboot def shuffle_single_variable(self, data): orig_shape = data.shape @@ -69,7 +63,6 @@ class BootStraps(RunEnvironment): return np.random.choice(data.reshape(-1,), size=size) - if __name__ == "__main__": from src.run_modules.experiment_setup import ExperimentSetup diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py new file mode 100644 index 0000000000000000000000000000000000000000..ed572161e535cc24332a76f4cbce86403eb1d067 --- /dev/null +++ b/test/test_data_handling/test_bootstraps.py @@ -0,0 +1,46 @@ + +from src.data_handling.bootstraps import BootStraps + +import pytest +import os + + +class TestBootstraps: + + @pytest.fixture + def boot_no_init(self): + obj = object.__new__(BootStraps) + super(BootStraps, obj).__init__() + obj.number_bootstraps = 50 + return obj + + def test_valid_bootstrap_file(self, boot_no_init): + path = os.path.join(os.path.dirname(__file__), "data") + os.makedirs(path) + boot_no_init.bootstrap_path = path + station = "TESTSTATION" + variables = "var1_var2_var3" + window = 5 + # empty case + assert len(os.listdir(path)) == 0 + assert boot_no_init.valid_bootstrap_file(station, variables, window) == (False, 50) + # different cases, where files with bigger range are existing + os.mknod(os.path.join(path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")) + assert boot_no_init.valid_bootstrap_file(station, variables, window) == (True, None) + os.mknod(os.path.join(path, f"{station}_{variables}_hist5_nboots100_shuffled.dat")) + assert boot_no_init.valid_bootstrap_file(station, variables, window) == (True, None) + os.mknod(os.path.join(path, f"{station}_{variables}_hist10_nboots50_shuffled.dat")) + os.mknod(os.path.join(path, f"{station}1_{variables}_hist10_nboots50_shuffled.dat")) + assert boot_no_init.valid_bootstrap_file(station, variables, window) == (True, None) + # need to reload data and therefore remove not fitting files for this station + assert boot_no_init.valid_bootstrap_file(station, variables, 20) == (False, 100) + assert len(os.listdir(path)) == 1 + # reload because expanded boot number + os.mknod(os.path.join(path, f"{station}_{variables}_hist5_nboots50_shuffled.dat")) + boot_no_init.number_bootstraps = 60 + assert boot_no_init.valid_bootstrap_file(station, variables, window) == (False, 60) + assert len(os.listdir(path)) == 1 + # reload because of expanded window size, but use maximum boot number from file names + os.mknod(os.path.join(path, f"{station}_{variables}_hist5_nboots60_shuffled.dat")) + boot_no_init.number_bootstraps = 50 + assert boot_no_init.valid_bootstrap_file(station, variables, 20) == (False, 60)