From 3eb4745ca4c6d96052a8e67f2f212b44cf7457db Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 26 Feb 2020 15:04:52 +0100
Subject: [PATCH] save orig labels locally

---
Review notes (this free-text area between the "---" marker and the first
"diff --git" line is not part of the commit):

* The hunks below were rebuilt from a copy of this patch that had lost its
  line breaks and indentation. Line breaks follow the line counts in the hunk
  headers and the diffstat; leading whitespace is inferred from the Python
  structure and must be checked against the target files before applying.
* Summary: adds BootStrapGenerator.get_labels() (yields every label of the
  original generator `boots` times), BootStraps.get_labels() (concatenates
  those labels along axis 0 and computes the result) and makes
  create_boot_straps() write them to "bootstraps_orig.nc".
* post_processing.py: calculate_skill_scores() and plot() are commented out
  and the bootstrap count passed to BootStraps drops from 20 to 2. Together
  with the two logging.info(<array>.shape) calls these look like temporary
  debugging changes; confirm they are intended before merging.
* post_processing.py: f"bootstraps_orig.nc" has no placeholders, a plain
  string literal is sufficient.
* bootstraps.py: BootStraps.get_labels() tests `len(labels_list) == 0` only to
  set `chunks` on the first iteration; `if chunks is None:` states that
  directly.
* BootStrapGenerator.get_labels() repeats each label array `boots` times,
  while create_boot_straps() reshapes the result to
  (length, window_lead_time, 1). `length` is defined outside this diff;
  verify that the sizes agree.

 src/data_handling/bootstraps.py    | 15 +++++++++++++++
 src/run_modules/post_processing.py | 12 +++++++++---
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
index 998ed8c6..60fc55fb 100644
--- a/src/data_handling/bootstraps.py
+++ b/src/data_handling/bootstraps.py
@@ -31,6 +31,11 @@ class BootStrapGenerator:
         """
         return len(self.orig_generator)*self.boots*len(self.variables)
 
+    def get_labels(self):
+        for (_, label) in self.orig_generator:
+            for _ in range(self.boots):
+                yield label
+
     def get_generator(self):
         """
         This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
@@ -85,6 +90,16 @@ class BootStraps(RunEnvironment):
     def get_boot_strap_generator_length(self):
         return self._boot_strap_generator.__len__()
 
+    def get_labels(self):
+        labels_list = []
+        chunks = None
+        for labels in self._boot_strap_generator.get_labels():
+            if len(labels_list) == 0:
+                chunks = (100, labels.data.shape[1])
+            labels_list.append(da.from_array(labels.data, chunks=chunks))
+        labels_out = da.concatenate(labels_list, axis=0)
+        return labels_out.compute()
+
     def get_chunk_size(self):
         hist, _ = self.data[0]
         return (100, *hist.shape[1:], self.number_bootstraps)
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 8a0df437..97f06812 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -49,15 +49,15 @@ class PostProcessing(RunEnvironment):
             self.make_prediction()
             logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
                          "skip make_prediction() whenever it is possible to save time.")
-        self.skill_scores = self.calculate_skill_scores()
-        self.plot()
+        # self.skill_scores = self.calculate_skill_scores()
+        # self.plot()
         self.create_boot_straps()
 
     def create_boot_straps(self):
         bootstrap_path = self.data_store.get("bootstrap_path", "general")
         forecast_path = self.data_store.get("forecast_path", "general")
         window_lead_time = self.data_store.get("window_lead_time", "general")
-        bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
+        bootstraps = BootStraps(self.test_data, bootstrap_path, 2)
         with TimeTracking(name="boot predictions"):
             bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
                                                                  steps=bootstraps.get_boot_strap_generator_length())
@@ -68,8 +68,14 @@ class PostProcessing(RunEnvironment):
             ind = (bootstrap_meta == boot)
             sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
             tmp = xr.DataArray(sel, coords=(range(length), range(window_lead_time), [boot]), dims=["index", "window", "boot"])
+            logging.info(tmp.shape)
             file_name = os.path.join(forecast_path, f"bootstraps_{boot}.nc")
             tmp.to_netcdf(file_name)
+        labels = bootstraps.get_labels().reshape((length, window_lead_time, 1))
+        file_name = os.path.join(forecast_path, f"bootstraps_orig.nc")
+        orig = xr.DataArray(labels, coords=(range(length), range(window_lead_time), ["orig"]), dims=["index", "window", "boot"])
+        logging.info(orig.shape)
+        orig.to_netcdf(file_name)
 
     def _load_model(self):
         try:
-- 
GitLab