diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 9247e6e19558fe29b9b8b6d2a0b0fe2cc2054a91..8690785659ab256fc78b4cfe8701461f67236a9b 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -18,60 +18,49 @@ class BootStrapGenerator: def __init__(self, orig_generator, boots, chunksize, bootstrap_path): self.orig_generator: DataGenerator = orig_generator self.stations = self.orig_generator.stations + self.variables = self.orig_generator.variables self.boots = boots self.chunksize = chunksize self.bootstrap_path = bootstrap_path self._iterator = 0 - self.__next__() - a = 1 def __len__(self): """ display the number of stations """ - return len(self.orig_generator)*self.boots + return len(self.orig_generator)*self.boots*len(self.variables) - def __iter__(self): - """ - Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute - `_iterator` to 0. - :return: - """ - self._iterator = 0 - return self + # def __iter__(self): + # """ + # Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute + # `_iterator` to 0. + # :return: + # """ + # self._iterator = 0 + # return self - def __next__(self): + def __iter__(self): """ This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return the history and label data of this generator. :return: """ - if self._iterator < self.__len__(): + while True: for i, data in enumerate(self.orig_generator): station = self.orig_generator.get_station_key(i) + logging.info(f"station: {station}") hist, label = data shuffled_data = self.load_boot_data(station) - all_variables = self.orig_generator.variables - for var in all_variables: + for var in self.variables: + logging.info(f" var: {var}") for boot in range(self.boots): - boot_hist: xr.DataArray = hist - boot_hist = boot_hist.sel(variables=helpers.list_pop(all_variables, var)) + logging.debug(f"boot: {boot}") + boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var)) shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables") boot_hist = boot_hist.combine_first(shuffled_var) - boot_hist.sortby("variables") - # boot_hist - - - - - # self._iterator += 1 - # if data.history is not None and data.label is not None: # pragma: no branch - # return data.history.transpose("datetime", "window", "Stations", "variables"), \ - # data.label.squeeze("Stations").transpose("datetime", "window") - else: - self.__next__() # pragma: no cover - else: - raise StopIteration + boot_hist = boot_hist.sortby("variables") + yield boot_hist, label + return def load_boot_data(self, station): files = os.listdir(self.bootstrap_path) @@ -88,11 +77,13 @@ class BootStraps(RunEnvironment): super().__init__() self.test_data: DataGenerator = self.data_store.get("generator", "general.test") - self.number_bootstraps = 50 + self.number_bootstraps = 10 self.bootstrap_path = self.data_store.get("bootstrap_path", "general") self.chunks = self.get_chunk_size() self.create_shuffled_data() - BootStrapGenerator(self.test_data, self.number_bootstraps, self.chunks, self.bootstrap_path) + bsg =BootStrapGenerator(self.test_data, self.number_bootstraps, self.chunks, self.bootstrap_path) + for bs in bsg: + hist, label = bs def get_chunk_size(self): hist, _ = self.test_data[0]