Skip to content
Snippets Groups Projects
Commit 3eb4745c authored by lukas leufen's avatar lukas leufen
Browse files

save orig labels locally

parent cf202273
Branches
Tags
2 merge requests: !59 "Develop", !52 "implemented bootstraps"
Pipeline #30529 passed
......@@ -31,6 +31,11 @@ class BootStrapGenerator:
"""
return len(self.orig_generator)*self.boots*len(self.variables)
def get_labels(self):
    """
    Yield the label of every element in the original generator, repeated.

    Each element of `self.orig_generator` is unpacked as a pair whose second entry is the label. This label is
    yielded `self.boots` times in a row before the next element is processed.

    :return: generator over labels
    """
    for _history, current_label in self.orig_generator:
        # the range is built once per element, the same label object is handed out on every repetition
        yield from (current_label for _ in range(self.boots))
def get_generator(self):
"""
This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
......@@ -85,6 +90,16 @@ class BootStraps(RunEnvironment):
def get_boot_strap_generator_length(self):
    """
    Return the length of the underlying boot strap generator.

    Uses the built-in `len()` instead of calling `__len__()` directly, so the usual checks of the length protocol
    (non-negative integer result) are applied.

    :return: number of elements reported by `self._boot_strap_generator`
    """
    return len(self._boot_strap_generator)
def get_labels(self):
    """
    Collect all labels of the boot strap generator into a single array.

    Every label yielded by `self._boot_strap_generator.get_labels()` is wrapped into a dask array, all arrays are
    concatenated along axis 0 and the result is computed before it is returned. The chunk layout is derived from
    the first label only (100 along the first axis, the size of the label's second axis along the second) and is
    reused for all following labels.

    :return: computed result of the concatenated label arrays
    """
    chunk_layout = None
    wrapped = []
    for label in self._boot_strap_generator.get_labels():
        if not wrapped:
            # first label seen: fix the chunk layout for this and all later labels
            chunk_layout = (100, label.data.shape[1])
        wrapped.append(da.from_array(label.data, chunks=chunk_layout))
    return da.concatenate(wrapped, axis=0).compute()
def get_chunk_size(self):
    """
    Build the chunk shape from the first element of `self.data`.

    The first element of `self.data` is unpacked as a pair and only its first entry is used. The returned tuple
    starts with 100, continues with all dimensions of that entry except the first one, and ends with
    `self.number_bootstraps`.

    :return: tuple describing the chunk shape
    """
    first_history, _unused = self.data[0]
    inner_dims = tuple(first_history.shape[1:])
    return (100,) + inner_dims + (self.number_bootstraps,)
......
......@@ -49,15 +49,15 @@ class PostProcessing(RunEnvironment):
self.make_prediction()
logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
"skip make_prediction() whenever it is possible to save time.")
self.skill_scores = self.calculate_skill_scores()
self.plot()
# self.skill_scores = self.calculate_skill_scores()
# self.plot()
self.create_boot_straps()
def create_boot_straps(self):
bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
bootstraps = BootStraps(self.test_data, bootstrap_path, 2)
with TimeTracking(name="boot predictions"):
bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
steps=bootstraps.get_boot_strap_generator_length())
......@@ -68,8 +68,14 @@ class PostProcessing(RunEnvironment):
ind = (bootstrap_meta == boot)
sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
tmp = xr.DataArray(sel, coords=(range(length), range(window_lead_time), [boot]), dims=["index", "window", "boot"])
logging.info(tmp.shape)
file_name = os.path.join(forecast_path, f"bootstraps_{boot}.nc")
tmp.to_netcdf(file_name)
labels = bootstraps.get_labels().reshape((length, window_lead_time, 1))
file_name = os.path.join(forecast_path, f"bootstraps_orig.nc")
orig = xr.DataArray(labels, coords=(range(length), range(window_lead_time), ["orig"]), dims=["index", "window", "boot"])
logging.info(orig.shape)
orig.to_netcdf(file_name)
def _load_model(self):
try:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment