Skip to content
Snippets Groups Projects
Commit 116edaa2 authored by lukas leufen's avatar lukas leufen
Browse files

add function to extract shuffled data (if locally stored data exceeds current dimensions)

parent 7722820f
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #33452 failed
...@@ -26,13 +26,23 @@ class BootStrapGenerator(keras.utils.Sequence): ...@@ -26,13 +26,23 @@ class BootStrapGenerator(keras.utils.Sequence):
def __len__(self): def __len__(self):
return self.number_of_boots return self.number_of_boots
def __getitem__(self, index): def __getitem__(self, index: int) -> xr.DataArray:
"""
return bootstrapped history for given bootstrap index in same index structure like the original history object
:param index: boot index e [0, nboots-1]
:return: bootstrapped history ready to use
"""
logging.debug(f"boot: {index}") logging.debug(f"boot: {index}")
boot_hist = self.history.copy() boot_hist = self.history.copy()
boot_hist = boot_hist.combine_first(self.__get_shuffled(index)) boot_hist = boot_hist.combine_first(self.__get_shuffled(index))
return boot_hist.reindex_like(self.history_orig) return boot_hist.reindex_like(self.history_orig)
def __get_shuffled(self, index): def __get_shuffled(self, index: int) -> xr.DataArray:
"""
returns shuffled data for given boot index from shuffled attribute
:param index: boot index e [0, nboots-1]
:return: shuffled data
"""
shuffled_var = self.shuffled.sel(boots=index).expand_dims("variables").drop("boots") shuffled_var = self.shuffled.sel(boots=index).expand_dims("variables").drop("boots")
return shuffled_var.transpose("datetime", "window", "Stations", "variables") return shuffled_var.transpose("datetime", "window", "Stations", "variables")
...@@ -139,6 +149,10 @@ class BootStraps: ...@@ -139,6 +149,10 @@ class BootStraps:
def variables(self): def variables(self):
return self.data.variables return self.data.variables
@property
def window_history_size(self):
return self.data.window_history_size
def get_generator(self, station, var): def get_generator(self, station, var):
""" """
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
...@@ -146,8 +160,8 @@ class BootStraps: ...@@ -146,8 +160,8 @@ class BootStraps:
:return: :return:
""" """
hist, _ = self.data[station] hist, _ = self.data[station]
shuffled_data = self._load_shuffled_data(station, self.variables) shuffled_data = self._load_shuffled_data(station, self.variables).reindex_like(hist)
return RealBootStrapGenerator(self.number_of_bootstraps, hist, shuffled_data, self.variables, var) return BootStrapGenerator(self.number_of_bootstraps, hist, shuffled_data, self.variables, var)
def get_labels(self, key: Union[str, int]): def get_labels(self, key: Union[str, int]):
""" """
...@@ -187,7 +201,7 @@ class BootStraps: ...@@ -187,7 +201,7 @@ class BootStraps:
def _get_shuffled_data_file(self, station, variables): def _get_shuffled_data_file(self, station, variables):
files = os.listdir(self.bootstrap_path) files = os.listdir(self.bootstrap_path)
regex = self._create_file_regex(station, variables) regex = self._create_file_regex(station, variables)
file = self._filter_files(regex, files, self.data.window_history_size, self.number_of_bootstraps) file = self._filter_files(regex, files, self.window_history_size, self.number_of_bootstraps)
if file: if file:
return os.path.join(self.bootstrap_path, file) return os.path.join(self.bootstrap_path, file)
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment