Skip to content
Snippets Groups Projects
Commit 2c4cd969 authored by lukas leufen's avatar lukas leufen
Browse files

removed slowest implementation

parent ae1a3737
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
......@@ -60,9 +60,6 @@ class PostProcessing(RunEnvironment):
with TimeTracking(name="split (refac_1): create_boot_straps_refac_2()"):
self.create_boot_straps_refac_2()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
with TimeTracking(name="split (refac): create_boot_straps_refac()"):
self.create_boot_straps_refac()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
with TimeTracking(name="merged: combined_boot_forecast_and_skill()"):
self.bootstrap_skill_scores = self.combined_boot_forecast_and_skill()
......@@ -73,45 +70,6 @@ class PostProcessing(RunEnvironment):
# plotting
# self.plot()
def create_boot_straps_refac(self):
# forecast
with TimeTracking(name="boot predictions"):
bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general")
number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
for station in bootstraps.stations:
with TimeTracking(name=station):
logging.info(station)
hist, label, station_bootstrap, length = bootstraps.get_generator_station_wise(station)
# make bootstrap predictions
bootstrap_predictions = self.model.predict_generator(generator=station_bootstrap(),
steps=length,
use_multiprocessing=True)
if isinstance(bootstrap_predictions, list):
bootstrap_predictions = bootstrap_predictions[-1]
# get bootstrap prediction meta data
bootstrap_meta = np.array(bootstraps.get_bootstrap_meta_station_wise(station))
# save bootstrap predictions separately for each station and variable combination
variables = np.unique(bootstrap_meta[:, 0])
coords = None
for boot in variables:
# store each variable - station - combination
ind = np.all(bootstrap_meta == [boot, station], axis=1)
length = sum(ind)
sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
coords = (range(length), range(1, window_lead_time + 1))
tmp = xr.DataArray(sel, coords=(*coords, [boot]), dims=["index", "ahead", "type"])
file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc")
tmp.to_netcdf(file_name)
# store also true labels for each station
labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1))
file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc")
labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"])
labels.to_netcdf(file_name)
def create_boot_straps_refac_2(self):
# forecast
with TimeTracking(name="boot predictions"):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment