From 3eb4745ca4c6d96052a8e67f2f212b44cf7457db Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 26 Feb 2020 15:04:52 +0100
Subject: [PATCH] save orig labels locally

---
 src/data_handling/bootstraps.py    | 15 +++++++++++++++
 src/run_modules/post_processing.py | 12 +++++++++---
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py
index 998ed8c6..60fc55fb 100644
--- a/src/data_handling/bootstraps.py
+++ b/src/data_handling/bootstraps.py
@@ -31,6 +31,11 @@ class BootStrapGenerator:
         """
         return len(self.orig_generator)*self.boots*len(self.variables)
 
+    def get_labels(self):
+        for (_, label) in self.orig_generator:
+            for _ in range(self.boots):
+                yield label
+
     def get_generator(self):
         """
         This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return
@@ -85,6 +90,16 @@ class BootStraps(RunEnvironment):
     def get_boot_strap_generator_length(self):
         return self._boot_strap_generator.__len__()
 
+    def get_labels(self):
+        labels_list = []
+        chunks = None
+        for labels in self._boot_strap_generator.get_labels():
+            if len(labels_list) == 0:
+                chunks = (100, labels.data.shape[1])
+            labels_list.append(da.from_array(labels.data, chunks=chunks))
+        labels_out = da.concatenate(labels_list, axis=0)
+        return labels_out.compute()
+
     def get_chunk_size(self):
         hist, _ = self.data[0]
         return (100, *hist.shape[1:], self.number_bootstraps)
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 8a0df437..97f06812 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -49,15 +49,15 @@ class PostProcessing(RunEnvironment):
             self.make_prediction()
             logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
                          "skip make_prediction() whenever it is possible to save time.")
-        self.skill_scores = self.calculate_skill_scores()
-        self.plot()
+        # self.skill_scores = self.calculate_skill_scores()
+        # self.plot()
         self.create_boot_straps()
 
     def create_boot_straps(self):
         bootstrap_path = self.data_store.get("bootstrap_path", "general")
         forecast_path = self.data_store.get("forecast_path", "general")
         window_lead_time = self.data_store.get("window_lead_time", "general")
-        bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
+        bootstraps = BootStraps(self.test_data, bootstrap_path, 2)
         with TimeTracking(name="boot predictions"):
             bootstrap_predictions = self.model.predict_generator(generator=bootstraps.boot_strap_generator(),
                                                                  steps=bootstraps.get_boot_strap_generator_length())
@@ -68,8 +68,14 @@ class PostProcessing(RunEnvironment):
             ind = (bootstrap_meta == boot)
             sel = bootstrap_predictions[ind].reshape((length, window_lead_time, 1))
             tmp = xr.DataArray(sel, coords=(range(length), range(window_lead_time), [boot]), dims=["index", "window", "boot"])
+            logging.info(tmp.shape)
             file_name = os.path.join(forecast_path, f"bootstraps_{boot}.nc")
             tmp.to_netcdf(file_name)
+        labels = bootstraps.get_labels().reshape((length, window_lead_time, 1))
+        file_name = os.path.join(forecast_path, f"bootstraps_orig.nc")
+        orig = xr.DataArray(labels, coords=(range(length), range(window_lead_time), ["orig"]), dims=["index", "window", "boot"])
+        logging.info(orig.shape)
+        orig.to_netcdf(file_name)
 
     def _load_model(self):
         try:
-- 
GitLab