diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index fb452702099e46051bd274864cb01a4493bd704a..3e86ed96b70f40ac77c9d3f7df1b774a2f56060d 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -26,13 +26,23 @@ class BootStrapGenerator(keras.utils.Sequence): def __len__(self): return self.number_of_boots - def __getitem__(self, index): + def __getitem__(self, index: int) -> xr.DataArray: + """ + Return the bootstrapped history for the given bootstrap index, with the same index structure as the original history object. + :param index: boot index in [0, nboots-1] + :return: bootstrapped history ready to use + """ logging.debug(f"boot: {index}") boot_hist = self.history.copy() boot_hist = boot_hist.combine_first(self.__get_shuffled(index)) return boot_hist.reindex_like(self.history_orig) - def __get_shuffled(self, index): + def __get_shuffled(self, index: int) -> xr.DataArray: + """ + Return the shuffled data for the given boot index from the shuffled attribute. + :param index: boot index in [0, nboots-1] + :return: shuffled data + """ shuffled_var = self.shuffled.sel(boots=index).expand_dims("variables").drop("boots") return shuffled_var.transpose("datetime", "window", "Stations", "variables") @@ -139,6 +149,10 @@ class BootStraps: def variables(self): return self.data.variables + @property + def window_history_size(self): + return self.data.window_history_size + def get_generator(self, station, var): """ This is the implementation of the __next__ method of the iterator protocol. 
Get the data generator, and return @@ -146,8 +160,8 @@ class BootStraps: :return: """ hist, _ = self.data[station] - shuffled_data = self._load_shuffled_data(station, self.variables) - return RealBootStrapGenerator(self.number_of_bootstraps, hist, shuffled_data, self.variables, var) + shuffled_data = self._load_shuffled_data(station, self.variables).reindex_like(hist) + return BootStrapGenerator(self.number_of_bootstraps, hist, shuffled_data, self.variables, var) def get_labels(self, key: Union[str, int]): """ @@ -187,7 +201,7 @@ class BootStraps: def _get_shuffled_data_file(self, station, variables): files = os.listdir(self.bootstrap_path) regex = self._create_file_regex(station, variables) - file = self._filter_files(regex, files, self.data.window_history_size, self.number_of_bootstraps) + file = self._filter_files(regex, files, self.window_history_size, self.number_of_bootstraps) if file: return os.path.join(self.bootstrap_path, file) else: