Skip to content
Snippets Groups Projects
Commit 79fb9653 authored by lukas leufen
Browse files

removed slow create_boot_straps() implementation

parent 2d3f6033
Branches
Tags
3 merge requests: !90 WIP: new release update, !89 Resolve "release branch / CI on gpu", !61 Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #32035 passed
......@@ -55,17 +55,14 @@ class PostProcessing(RunEnvironment):
if self.data_store.get("evaluate_bootstraps", "general.postprocessing"):
bootstrap_path = self.data_store.get("bootstrap_path", "general")
BootStraps(self.test_data, bootstrap_path, 20)
with TimeTracking(name="split (refac_1)"):
with TimeTracking(name="split (refac_1): create_boot_straps_refac_2()"):
self.create_boot_straps_refac_2()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
with TimeTracking(name="split (refac)"):
with TimeTracking(name="split (refac): create_boot_straps_refac()"):
self.create_boot_straps_refac()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
with TimeTracking(name="merged"):
with TimeTracking(name="merged: combined_boot_forecast_and_skill()"):
self.bootstrap_skill_scores = self.combined_boot_forecast_and_skill()
with TimeTracking(name="original version"):
self.create_boot_straps()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
# skill scores
......@@ -74,43 +71,6 @@ class PostProcessing(RunEnvironment):
# plotting
# self.plot()
def create_boot_straps(self, number_of_boots=20):
    """
    Predict on bootstrapped test data and store the results as netCDF files.

    Reads ``bootstrap_path``, ``forecast_path`` and ``window_lead_time`` from the data store (scope ``general``),
    runs ``self.model.predict_generator`` on the generator provided by ``BootStraps`` and writes into
    ``forecast_path``:

    * ``bootstraps_{boot}_{station}.nc`` for each variable / station combination found in the bootstrap meta data
    * ``bootstraps_labels_{station}.nc`` with the true labels (type ``obs``) for each station

    All arrays are stored with the dimensions ``index``, ``ahead`` and ``type``.

    :param number_of_boots: value handed to ``BootStraps`` as third argument (was hard-coded to 20 before);
        presumably the number of bootstrap realisations - confirm against ``BootStraps``.
    """
    # NOTE(review): block nesting was reconstructed from a listing without indentation; confirm that the timed
    # section is meant to span the whole method and that labels are written once per station.
    with TimeTracking(name="boot predictions"):
        bootstrap_path = self.data_store.get("bootstrap_path", "general")
        forecast_path = self.data_store.get("forecast_path", "general")
        window_lead_time = self.data_store.get("window_lead_time", "general")
        bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_boots)

        # make bootstrap predictions
        logging.info("predictions")
        bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
                                                             steps=bootstraps.get_boot_strap_generator_length(),
                                                             use_multiprocessing=True)
        # if the prediction comes back as a list, only its last entry is used
        if isinstance(bootstrap_predictions, list):
            bootstrap_predictions = bootstrap_predictions[-1]

        # get bootstrap prediction meta data; column 0 is used as variable, column 1 as station
        bootstrap_meta = np.array(bootstraps.get_boot_strap_meta())

        # save bootstrap predictions separately for each station and variable combination
        variables = np.unique(bootstrap_meta[:, 0])
        for station in np.unique(bootstrap_meta[:, 1]):
            logging.info(station)
            coords = None
            for boot in variables:
                # store each variable - station - combination
                ind = np.all(bootstrap_meta == [boot, station], axis=1)
                # count selected rows in numpy instead of iterating the mask with the builtin sum()
                length = np.count_nonzero(ind)
                sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
                coords = (range(length), range(1, window_lead_time + 1))
                tmp = xr.DataArray(sel, coords=(*coords, [boot]), dims=["index", "ahead", "type"])
                file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc")
                tmp.to_netcdf(file_name)

            # store also true labels for each station; length and coords are deliberately taken from the last
            # variable of the loop above, so the reshape fails if the labels do not match that sample count
            labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1))
            file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
            labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
            labels.to_netcdf(file_name)
def create_boot_straps_refac(self):
# forecast
with TimeTracking(name="boot predictions"):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment