Skip to content
Snippets Groups Projects
Commit b568c246 authored by lukas leufen's avatar lukas leufen
Browse files

small cleanup

parent c271db81
Branches
Tags
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #32464 passed
......@@ -25,66 +25,6 @@ class BootStrapGenerator:
def __len__(self):
return len(self.orig_generator) * self.number_of_boots * len(self.variables)
def get_generator(self):
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator.
:return:
"""
while True:
for i, data in enumerate(self.orig_generator):
station = self.orig_generator.get_station_key(i)
logging.info(f"station: {station}")
hist, label = data
shuffled_data = self.load_shuffled_data(station, self.variables)
for var in self.variables:
logging.debug(f" var: {var}")
for boot in range(self.number_of_boots):
logging.debug(f"boot: {boot}")
boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
boot_hist = boot_hist.combine_first(shuffled_var)
boot_hist = boot_hist.sortby("variables")
yield boot_hist, label
return
def get_generator_station_wise(self, station):
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
the history and label data of this generator.
:return:
"""
# logging.info(f"station: {station}")
hist, label = self.orig_generator[station]
shuffled_data = self.load_shuffled_data(station, self.variables)
def f():
while True:
for var in self.variables:
logging.debug(f" var: {var}")
for boot in range(self.number_of_boots):
logging.debug(f"boot: {boot}")
boot_hist = hist.sel(variables=helpers.list_pop(self.variables, var))
shuffled_var = shuffled_data.sel(variables=var, boots=boot).expand_dims("variables").drop("boots").transpose("datetime", "window", "Stations", "variables")
boot_hist = boot_hist.combine_first(shuffled_var)
boot_hist = boot_hist.sortby("variables")
yield boot_hist
return
return hist, label, f, self.number_of_boots * len(self.variables)
def get_bootstrap_meta_station_wise(self, station) -> List:
"""
Create meta data on ordering of variable bootstraps according to ordering from get_generator method.
:return: list with bootstrapped variable first and its corresponding station second.
"""
bootstrap_meta = []
label = self.orig_generator.get_data_generator(station).get_transposed_label()
for var in self.variables:
for boot in range(self.number_of_boots):
bootstrap_meta.extend([[var, station]] * len(label))
return bootstrap_meta
def get_generator_station_var_wise(self, station, var):
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
......@@ -118,19 +58,6 @@ class BootStrapGenerator:
bootstrap_meta.extend([[var, station]] * len(label))
return bootstrap_meta
def get_bootstrap_meta(self) -> List:
"""
Create meta data on ordering of variable bootstraps according to ordering from get_generator method.
:return: list with bootstrapped variable first and its corresponding station second.
"""
bootstrap_meta = []
for station in self.stations:
label = self.orig_generator.get_data_generator(station).get_transposed_label()
for var in self.variables:
for boot in range(self.number_of_boots):
bootstrap_meta.extend([[var, station]] * len(label))
return bootstrap_meta
def get_labels(self, key: Union[str, int]):
"""
Reepats labels for given key by the number of boots and yield it one by one.
......@@ -194,24 +121,12 @@ class BootStraps:
def variables(self):
return self._boot_strap_generator.variables
def get_generator_station_wise(self, station):
return self._boot_strap_generator.get_generator_station_wise(station)
def get_generator_station_var_wise(self, station, var):
return self._boot_strap_generator.get_generator_station_var_wise(station, var)
def get_bootstrap_meta_station_wise(self, station):
return self._boot_strap_generator.get_bootstrap_meta_station_wise(station)
def get_bootstrap_meta_station_var_wise(self, station, var):
return self._boot_strap_generator.get_bootstrap_meta_station_var_wise(station, var)
def get_boot_strap_meta(self):
return self._boot_strap_generator.get_bootstrap_meta()
def boot_strap_generator(self):
return self._boot_strap_generator.get_generator()
def get_boot_strap_generator_length(self):
return self._boot_strap_generator.__len__()
......
......@@ -80,7 +80,6 @@ class PostProcessing(RunEnvironment):
logging.info("Couldn't load all files, restart bootstrap postprocessing with create_new_bootstraps=True.")
self.bootstrap_postprocessing(True, _iter=1)
def create_boot_straps(self):
# forecast
with TimeTracking(name="boot predictions"):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment