From 3460756ad8fb7e1698231cc63a850740b7422dd3 Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Wed, 18 Mar 2020 15:59:33 +0100
Subject: [PATCH] number of bootstraps is now a parameter from experiment setup

---
 src/run_modules/experiment_setup.py |  3 ++-
 src/run_modules/post_processing.py  | 18 +++++++++++-------
 2 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py
index 53ef2c61..a420e287 100644
--- a/src/run_modules/experiment_setup.py
+++ b/src/run_modules/experiment_setup.py
@@ -37,7 +37,7 @@ class ExperimentSetup(RunEnvironment):
                  test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
                  experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
                  create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
-                 evaluate_bootstraps=True, plot_list=None):
+                 evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None):
 
         # create run framework
         super().__init__()
@@ -120,6 +120,7 @@ class ExperimentSetup(RunEnvironment):
 
         # set post-processing instructions
         self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing")
+        self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing")
         self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
 
     def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index 9b5ab848..53e2ca7a 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -53,8 +53,9 @@ class PostProcessing(RunEnvironment):
 
         # bootstraps
         if self.data_store.get("evaluate_bootstraps", "general.postprocessing"):
-            bootstrap_path = self.data_store.get("bootstrap_path", "general")
-            BootStraps(self.test_data, bootstrap_path, 20)
+            bootstrap_path = self.data_store.get("bootstrap_path", "general.postprocessing")
+            number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
+            BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
             with TimeTracking(name="split (refac_1): create_boot_straps_refac_2()"):
                 self.create_boot_straps_refac_2()
                 self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
@@ -77,7 +78,8 @@ class PostProcessing(RunEnvironment):
             bootstrap_path = self.data_store.get("bootstrap_path", "general")
             forecast_path = self.data_store.get("forecast_path", "general")
             window_lead_time = self.data_store.get("window_lead_time", "general")
-            bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
+            number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
+            bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
             for station in bootstraps.stations:
                 with TimeTracking(name=station):
                     logging.info(station)
@@ -115,7 +117,8 @@ class PostProcessing(RunEnvironment):
             bootstrap_path = self.data_store.get("bootstrap_path", "general")
             forecast_path = self.data_store.get("forecast_path", "general")
             window_lead_time = self.data_store.get("window_lead_time", "general")
-            bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
+            number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
+            bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
 
             for station in bootstraps.stations:
                 with TimeTracking(name=station):
@@ -153,7 +156,8 @@ class PostProcessing(RunEnvironment):
             bootstrap_path = self.data_store.get("bootstrap_path", "general")
             forecast_path = self.data_store.get("forecast_path", "general")
             window_lead_time = self.data_store.get("window_lead_time", "general")
-            bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
+            number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
+            bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
 
             # calc skill scores
             skill_scores = statistics.SkillScores(None)
@@ -185,7 +189,8 @@ class PostProcessing(RunEnvironment):
             bootstrap_path = self.data_store.get("bootstrap_path", "general")
             forecast_path = self.data_store.get("forecast_path", "general")
             window_lead_time = self.data_store.get("window_lead_time", "general")
-            bootstraps = BootStraps(self.test_data, bootstrap_path, 20)
+            number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
+            bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
             skill_scores = statistics.SkillScores(None)
             score = {}
 
@@ -234,7 +239,6 @@ class PostProcessing(RunEnvironment):
                     score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"])
             return score
 
-
     def _load_model(self):
         try:
             model = self.data_store.get("best_model", "general")
-- 
GitLab