Skip to content
Snippets Groups Projects
Commit 116edaa2 authored by lukas leufen's avatar lukas leufen
Browse files

add function to extract shuffled data (if locally stored data exceeds current dimensions)

parent 7722820f
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #33452 failed
......@@ -26,13 +26,23 @@ class BootStrapGenerator(keras.utils.Sequence):
def __len__(self):
return self.number_of_boots
def __getitem__(self, index):
def __getitem__(self, index: int) -> xr.DataArray:
"""
return bootstrapped history for given bootstrap index in same index structure like the original history object
:param index: boot index e [0, nboots-1]
:return: bootstrapped history ready to use
"""
logging.debug(f"boot: {index}")
boot_hist = self.history.copy()
boot_hist = boot_hist.combine_first(self.__get_shuffled(index))
return boot_hist.reindex_like(self.history_orig)
def __get_shuffled(self, index):
def __get_shuffled(self, index: int) -> xr.DataArray:
"""
returns shuffled data for given boot index from shuffled attribute
:param index: boot index e [0, nboots-1]
:return: shuffled data
"""
shuffled_var = self.shuffled.sel(boots=index).expand_dims("variables").drop("boots")
return shuffled_var.transpose("datetime", "window", "Stations", "variables")
......@@ -139,6 +149,10 @@ class BootStraps:
def variables(self):
return self.data.variables
@property
def window_history_size(self):
return self.data.window_history_size
def get_generator(self, station, var):
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
......@@ -146,8 +160,8 @@ class BootStraps:
:return:
"""
hist, _ = self.data[station]
shuffled_data = self._load_shuffled_data(station, self.variables)
return RealBootStrapGenerator(self.number_of_bootstraps, hist, shuffled_data, self.variables, var)
shuffled_data = self._load_shuffled_data(station, self.variables).reindex_like(hist)
return BootStrapGenerator(self.number_of_bootstraps, hist, shuffled_data, self.variables, var)
def get_labels(self, key: Union[str, int]):
"""
......@@ -187,7 +201,7 @@ class BootStraps:
def _get_shuffled_data_file(self, station, variables):
files = os.listdir(self.bootstrap_path)
regex = self._create_file_regex(station, variables)
file = self._filter_files(regex, files, self.data.window_history_size, self.number_of_bootstraps)
file = self._filter_files(regex, files, self.window_history_size, self.number_of_bootstraps)
if file:
return os.path.join(self.bootstrap_path, file)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment