diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 3389c659facf88383330ec0543760ce99b127aec..9b5ab8484a1252e44f8f7cf264b5afd6082cc954 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -55,17 +55,14 @@ class PostProcessing(RunEnvironment): if self.data_store.get("evaluate_bootstraps", "general.postprocessing"): bootstrap_path = self.data_store.get("bootstrap_path", "general") BootStraps(self.test_data, bootstrap_path, 20) - with TimeTracking(name="split (refac_1)"): + with TimeTracking(name="split (refac_1): create_boot_straps_refac_2()"): self.create_boot_straps_refac_2() self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores() - with TimeTracking(name="split (refac)"): + with TimeTracking(name="split (refac): create_boot_straps_refac()"): self.create_boot_straps_refac() self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores() - with TimeTracking(name="merged"): + with TimeTracking(name="merged: combined_boot_forecast_and_skill()"): self.bootstrap_skill_scores = self.combined_boot_forecast_and_skill() - with TimeTracking(name="original version"): - self.create_boot_straps() - self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores() # skill scores @@ -74,43 +71,6 @@ class PostProcessing(RunEnvironment): # plotting # self.plot() - def create_boot_straps(self): - # forecast - with TimeTracking(name="boot predictions"): - bootstrap_path = self.data_store.get("bootstrap_path", "general") - forecast_path = self.data_store.get("forecast_path", "general") - window_lead_time = self.data_store.get("window_lead_time", "general") - bootstraps = BootStraps(self.test_data, bootstrap_path, 20) - # make bootstrap predictions - logging.info("predictions") - bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(), - steps=bootstraps.get_boot_strap_generator_length(), - use_multiprocessing=True) - - if isinstance(bootstrap_predictions, list): - bootstrap_predictions = bootstrap_predictions[-1] - # get bootstrap prediction meta data - bootstrap_meta = np.array(bootstraps.get_boot_strap_meta()) - # save bootstrap predictions separately for each station and variable combination - variables = np.unique(bootstrap_meta[:, 0]) - for station in np.unique(bootstrap_meta[:, 1]): - logging.info(station) - coords = None - for boot in variables: - # store each variable - station - combination - ind = np.all(bootstrap_meta == [boot, station], axis=1) - length = sum(ind) - sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1)) - coords = (range(length), range(1, window_lead_time + 1)) - tmp = xr.DataArray(sel, coords=(*coords, [boot]), dims=["index", "ahead", "type"]) - file_name = os.path.join(forecast_path, f"bootstraps_{boot}_{station}.nc") - tmp.to_netcdf(file_name) - # store also true labels for each station - labels = bootstraps.get_labels(station).reshape((length, window_lead_time, 1)) - file_name = os.path.join(forecast_path, f"bootstraps_labels_{station}.nc") - labels = xr.DataArray(labels, coords=(*coords, ["obs"]), dims=["index", "ahead", "type"]) - labels.to_netcdf(file_name) - def create_boot_straps_refac(self): # forecast with TimeTracking(name="boot predictions"):