diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 2700cb6c0aaff7227bfea78116a5417aeae9b201..6ac33e2a2555fe3f253593423e9e71e0aa97f4af 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -22,17 +22,22 @@ class BootStraps(RunEnvironment): self.create_shuffled_data() def create_shuffled_data(self): + """ + Create shuffled data. Use original test data, add dimension 'boots' with length number of bootstraps and insert + randomly selected variables. If there is a suitable local file for requested window size and number of + bootstraps, no additional file will be created inside this function. + """ variables_str = '_'.join(sorted(self.test_data.variables)) window = self.test_data.window_history_size for station in self.test_data.stations: - valid, max_nboot = self.valid_bootstrap_file(station, variables_str, window) + valid, nboot = self.valid_bootstrap_file(station, variables_str, window) if not valid: logging.info(f'create bootstap data for {station}') hist, _ = self.test_data[station] data = hist.copy() - file_name = f"{station}_{variables_str}_hist{window}_nboots{max_nboot}_shuffled.nc" + file_name = f"{station}_{variables_str}_hist{window}_nboots{nboot}_shuffled.nc" file_path = os.path.join(self.bootstrap_path, file_name) - data = data.expand_dims({'boots': range(max_nboot)}, axis=-1) + data = data.expand_dims({'boots': range(nboot)}, axis=-1) shuffled_variable = np.full(data.shape, np.nan) for i, var in enumerate(data.coords['variables']): single_variable = data.sel(variables=var).values @@ -41,25 +46,37 @@ class BootStraps(RunEnvironment): shuffled_data.to_netcdf(file_path) def valid_bootstrap_file(self, station, variables, window): - str_re = re.compile(f"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled*") - dir_list = os.listdir(self.bootstrap_path) + """ + Compare local bootstrap file with given settings for station, variables, window and number of bootstraps. If a + match was found, this method returns a tuple (True, None). In any other case, it returns (False, max_nboot), + where max_nboot is the highest boot number found in the local storage. A match is defined so that the window + length is ge than given window size form args and the number of boots is also ge than the given number of boots + from this class. Furthermore, this functions deletes local files, if the match the station pattern but don't fit + the window and bootstrap condition. This is performed, because it is assumed, that the corresponding file will + be created with a longer or at least same window size and numbers of bootstraps. + :param station: + :param variables: + :param window: + :return: + """ + regex = re.compile(rf"{station}_{variables}_hist(\d+)_nboots(\d+)_shuffled*") max_nboot = self.number_bootstraps - for file in dir_list: - match = str_re.match(file) + for file in os.listdir(self.bootstrap_path): + match = regex.match(file) if match: - window_existing = int(match.group(1)) - nboot_existing = int(match.group(2)) - max_nboot = max([max_nboot, nboot_existing]) - if (window_existing >= window) and (nboot_existing >= self.number_bootstraps): + window_file = int(match.group(1)) + nboot_file = int(match.group(2)) + max_nboot = max([max_nboot, nboot_file]) + if (window_file >= window) and (nboot_file >= self.number_bootstraps): return True, None else: os.remove(os.path.join(self.bootstrap_path, file)) return False, max_nboot - def shuffle_single_variable(self, data): + @staticmethod + def shuffle_single_variable(data: np.ndarray) -> np.ndarray: orig_shape = data.shape size = orig_shape - # size = (*orig_shape, self.number_bootstraps) return np.random.choice(data.reshape(-1,), size=size) diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py index ed572161e535cc24332a76f4cbce86403eb1d067..c1edd7ca7f012ccdebc2c75119eb37c5bc56c125 100644 --- a/test/test_data_handling/test_bootstraps.py +++ b/test/test_data_handling/test_bootstraps.py @@ -4,20 +4,27 @@ from src.data_handling.bootstraps import BootStraps import pytest import os +import numpy as np + class TestBootstraps: @pytest.fixture - def boot_no_init(self): + def path(self): + path = os.path.join(os.path.dirname(__file__), "data") + if not os.path.exists(path): + os.makedirs(path) + return path + + @pytest.fixture + def boot_no_init(self, path): obj = object.__new__(BootStraps) super(BootStraps, obj).__init__() obj.number_bootstraps = 50 + obj.bootstrap_path = path return obj - def test_valid_bootstrap_file(self, boot_no_init): - path = os.path.join(os.path.dirname(__file__), "data") - os.makedirs(path) - boot_no_init.bootstrap_path = path + def test_valid_bootstrap_file(self, path, boot_no_init): station = "TESTSTATION" variables = "var1_var2_var3" window = 5 @@ -44,3 +51,14 @@ class TestBootstraps: os.mknod(os.path.join(path, f"{station}_{variables}_hist5_nboots60_shuffled.dat")) boot_no_init.number_bootstraps = 50 assert boot_no_init.valid_bootstrap_file(station, variables, 20) == (False, 60) + + def test_shuffle_single_variale(self, boot_no_init): + data = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) + res = boot_no_init.shuffle_single_variable(data) + assert res.shape == data.shape + assert res.max() == data.max() + assert res.min() == data.min() + assert set(np.unique(res)).issubset({1, 2, 3}) + + def test_create_shuffled_data(self): + pass \ No newline at end of file