From 3460756ad8fb7e1698231cc63a850740b7422dd3 Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Wed, 18 Mar 2020 15:59:33 +0100 Subject: [PATCH] number of bootstraps is now a parameter from experiment setup --- src/run_modules/experiment_setup.py | 3 ++- src/run_modules/post_processing.py | 18 +++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 53ef2c61..a420e287 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -37,7 +37,7 @@ class ExperimentSetup(RunEnvironment): test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None, - evaluate_bootstraps=True, plot_list=None): + evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None): # create run framework super().__init__() @@ -120,6 +120,7 @@ class ExperimentSetup(RunEnvironment): # set post-processing instructions self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing") + self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing") self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 9b5ab848..53e2ca7a 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -53,8 +53,9 @@ class PostProcessing(RunEnvironment): # bootstraps if self.data_store.get("evaluate_bootstraps", "general.postprocessing"): - bootstrap_path = self.data_store.get("bootstrap_path", "general") - BootStraps(self.test_data, bootstrap_path, 20) + bootstrap_path = self.data_store.get("bootstrap_path", "general.postprocessing") + number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing") + BootStraps(self.test_data, bootstrap_path, number_of_bootstraps) with TimeTracking(name="split (refac_1): create_boot_straps_refac_2()"): self.create_boot_straps_refac_2() self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores() @@ -77,7 +78,8 @@ class PostProcessing(RunEnvironment): bootstrap_path = self.data_store.get("bootstrap_path", "general") forecast_path = self.data_store.get("forecast_path", "general") window_lead_time = self.data_store.get("window_lead_time", "general") - bootstraps = BootStraps(self.test_data, bootstrap_path, 20) + number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing") + bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps) for station in bootstraps.stations: with TimeTracking(name=station): logging.info(station) @@ -115,7 +117,8 @@ class PostProcessing(RunEnvironment): bootstrap_path = self.data_store.get("bootstrap_path", "general") forecast_path = self.data_store.get("forecast_path", "general") window_lead_time = self.data_store.get("window_lead_time", "general") - bootstraps = BootStraps(self.test_data, bootstrap_path, 20) + number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing") + bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps) for station in bootstraps.stations: with TimeTracking(name=station): @@ -153,7 +156,8 @@ class PostProcessing(RunEnvironment): bootstrap_path = self.data_store.get("bootstrap_path", "general") forecast_path = self.data_store.get("forecast_path", "general") window_lead_time = self.data_store.get("window_lead_time", "general") - bootstraps = BootStraps(self.test_data, bootstrap_path, 20) + number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing") + bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps) # calc skill scores skill_scores = statistics.SkillScores(None) @@ -185,7 +189,8 @@ class PostProcessing(RunEnvironment): bootstrap_path = self.data_store.get("bootstrap_path", "general") forecast_path = self.data_store.get("forecast_path", "general") window_lead_time = self.data_store.get("window_lead_time", "general") - bootstraps = BootStraps(self.test_data, bootstrap_path, 20) + number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing") + bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps) skill_scores = statistics.SkillScores(None) score = {} @@ -234,7 +239,6 @@ class PostProcessing(RunEnvironment): score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"]) return score - def _load_model(self): try: model = self.data_store.get("best_model", "general") -- GitLab