From 271c10d069338d7175f393a51a74274d326fcb3a Mon Sep 17 00:00:00 2001
From: leufen1 <l.leufen@fz-juelich.de>
Date: Fri, 16 Jul 2021 14:57:15 +0200
Subject: [PATCH] bootstrap method and type can be set during experiment setup

---
 mlair/configuration/defaults.py       | 2 ++
 mlair/run_modules/experiment_setup.py | 9 ++++++---
 mlair/run_modules/post_processing.py  | 9 ++++++---
 3 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/mlair/configuration/defaults.py b/mlair/configuration/defaults.py
index 088a504a..d61146b6 100644
--- a/mlair/configuration/defaults.py
+++ b/mlair/configuration/defaults.py
@@ -46,6 +46,8 @@ DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS = True
 DEFAULT_EVALUATE_BOOTSTRAPS = True
 DEFAULT_CREATE_NEW_BOOTSTRAPS = False
 DEFAULT_NUMBER_OF_BOOTSTRAPS = 20
+DEFAULT_BOOTSTRAP_TYPE = "singleinput"
+DEFAULT_BOOTSTRAP_METHOD = "shuffle"
 DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries",
                      "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "PlotConditionalQuantiles",
                      "PlotAvailability", "PlotAvailabilityHistogram", "PlotDataHistogram", "PlotPeriodogram"]
diff --git a/mlair/run_modules/experiment_setup.py b/mlair/run_modules/experiment_setup.py
index bd06914f..8036413c 100644
--- a/mlair/run_modules/experiment_setup.py
+++ b/mlair/run_modules/experiment_setup.py
@@ -19,7 +19,8 @@ from mlair.configuration.defaults import DEFAULT_STATIONS, DEFAULT_VAR_ALL_DICT,
     DEFAULT_VAL_MIN_LENGTH, DEFAULT_TEST_START, DEFAULT_TEST_END, DEFAULT_TEST_MIN_LENGTH, DEFAULT_TRAIN_VAL_MIN_LENGTH, \
     DEFAULT_USE_ALL_STATIONS_ON_ALL_DATA_SETS, DEFAULT_EVALUATE_BOOTSTRAPS, DEFAULT_CREATE_NEW_BOOTSTRAPS, \
     DEFAULT_NUMBER_OF_BOOTSTRAPS, DEFAULT_PLOT_LIST, DEFAULT_SAMPLING, DEFAULT_DATA_ORIGIN, DEFAULT_ITER_DIM, \
-    DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING
+    DEFAULT_USE_MULTIPROCESSING, DEFAULT_USE_MULTIPROCESSING_ON_DEBUG, DEFAULT_MAX_NUMBER_MULTIPROCESSING, \
+    DEFAULT_BOOTSTRAP_TYPE, DEFAULT_BOOTSTRAP_METHOD
 from mlair.data_handler import DefaultDataHandler
 from mlair.run_modules.run_environment import RunEnvironment
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16 as VanillaModel
@@ -211,8 +212,8 @@ class ExperimentSetup(RunEnvironment):
                  create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
                  train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
                  extremes_on_right_tail_only: bool = None, evaluate_bootstraps=None, plot_list=None,
-                 number_of_bootstraps=None,
-                 create_new_bootstraps=None, data_path: str = None, batch_path: str = None, login_nodes=None,
+                 number_of_bootstraps=None, create_new_bootstraps=None, bootstrap_method=None, bootstrap_type=None,
+                 data_path: str = None, batch_path: str = None, login_nodes=None,
                  hpc_hosts=None, model=None, batch_size=None, epochs=None, data_handler=None,
                  data_origin: Dict = None, competitors: list = None, competitor_path: str = None,
                  use_multiprocessing: bool = None, use_multiprocessing_on_debug: bool = None,
@@ -347,6 +348,8 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("create_new_bootstraps", create_new_bootstraps, scope="general.postprocessing")
         self._set_param("number_of_bootstraps", number_of_bootstraps, default=DEFAULT_NUMBER_OF_BOOTSTRAPS,
                         scope="general.postprocessing")
+        self._set_param("bootstrap_method", bootstrap_method, default=DEFAULT_BOOTSTRAP_METHOD)
+        self._set_param("bootstrap_type", bootstrap_type, default=DEFAULT_BOOTSTRAP_TYPE)
         self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
         self._set_param("neighbors", ["DEBW030"])  # TODO: just for testing
 
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index c4ce0088..57b4d6ef 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -103,7 +103,10 @@ class PostProcessing(RunEnvironment):
         if self.data_store.get("evaluate_bootstraps", "postprocessing"):
             with TimeTracking(name="calculate bootstraps"):
                 create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing")
-                self.bootstrap_postprocessing(create_new_bootstraps)  # todo: make flexible and add boot method and type
+                bootstrap_method = self.data_store.get("bootstrap_method", "postprocessing")
+                bootstrap_type = self.data_store.get("bootstrap_type", "postprocessing")
+                self.bootstrap_postprocessing(create_new_bootstraps, bootstrap_type=bootstrap_type,
+                                              bootstrap_method=bootstrap_method)
 
         # skill scores and error metrics
         with TimeTracking(name="calculate skill scores"):
@@ -136,8 +139,8 @@ class PostProcessing(RunEnvironment):
                 continue
         return xr.concat(competing_predictions, "type") if len(competing_predictions) > 0 else None
 
-    def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type=None,
-                                 bootstrap_method=None) -> None:
+    def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0, bootstrap_type="singleinput",
+                                 bootstrap_method="shuffle") -> None:
         """
         Calculate skill scores of bootstrapped data.
 
-- 
GitLab