diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 998ed8c6990d6a16388874c52faf4931ef4ba174..60fc55fbb00a5ec42c1ee8e06c17d848106a08dc 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -31,6 +31,11 @@ class BootStrapGenerator: """ return len(self.orig_generator)*self.boots*len(self.variables) + def get_labels(self): + for (_, label) in self.orig_generator: + for _ in range(self.boots): + yield label + def get_generator(self): """ This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return @@ -85,6 +90,16 @@ class BootStraps(RunEnvironment): def get_boot_strap_generator_length(self): return self._boot_strap_generator.__len__() + def get_labels(self): + labels_list = [] + chunks = None + for labels in self._boot_strap_generator.get_labels(): + if len(labels_list) == 0: + chunks = (100, labels.data.shape[1]) + labels_list.append(da.from_array(labels.data, chunks=chunks)) + labels_out = da.concatenate(labels_list, axis=0) + return labels_out.compute() + def get_chunk_size(self): hist, _ = self.data[0] return (100, *hist.shape[1:], self.number_bootstraps) diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 8a0df43756acfdba0c86be7eecc8e4da37999ce3..97f068123bfb967ab34789952a296f1a8632a516 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -49,15 +49,15 @@ class PostProcessing(RunEnvironment): self.make_prediction() logging.info("take a look on the next reported time measure. If this increases a lot, one should think to " "skip make_prediction() whenever it is possible to save time.") - self.skill_scores = self.calculate_skill_scores() - self.plot() + # self.skill_scores = self.calculate_skill_scores() + # self.plot() self.create_boot_straps() def create_boot_straps(self): bootstrap_path = self.data_store.get("bootstrap_path", "general") forecast_path = self.data_store.get("forecast_path", "general") window_lead_time = self.data_store.get("window_lead_time", "general") - bootstraps = BootStraps(self.test_data, bootstrap_path, 20) + bootstraps = BootStraps(self.test_data, bootstrap_path, 2) with TimeTracking(name="boot predictions"): bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(), steps=bootstraps.get_boot_strap_generator_length()) @@ -68,8 +68,14 @@ class PostProcessing(RunEnvironment): ind = (bootstrap_meta == boot) sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1)) tmp = xr.DataArray(sel, coords=(range(length), range(window_lead_time), [boot]), dims=["index", "window", "boot"]) + logging.info(tmp.shape) file_name = os.path.join(forecast_path, f"bootstraps_{boot}.nc") tmp.to_netcdf(file_name) + labels = bootstraps.get_labels().reshape((length, window_lead_time, 1)) + file_name = os.path.join(forecast_path, f"bootstraps_orig.nc") + orig = xr.DataArray(labels, coords=(range(length), range(window_lead_time), ["orig"]), dims=["index", "window", "boot"]) + logging.info(orig.shape) + orig.to_netcdf(file_name) def _load_model(self): try: