Skip to content
Snippets Groups Projects
Commit 3460756a authored by lukas leufen's avatar lukas leufen
Browse files

number of bootstraps is now a parameter from experiment setup

parent 79fb9653
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!61Resolve "REFAC: clean-up bootstrap workflow"
Pipeline #32248 passed
...@@ -37,7 +37,7 @@ class ExperimentSetup(RunEnvironment): ...@@ -37,7 +37,7 @@ class ExperimentSetup(RunEnvironment):
test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None,
experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily",
create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None, create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None,
evaluate_bootstraps=True, plot_list=None): evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None):
# create run framework # create run framework
super().__init__() super().__init__()
...@@ -120,6 +120,7 @@ class ExperimentSetup(RunEnvironment): ...@@ -120,6 +120,7 @@ class ExperimentSetup(RunEnvironment):
# set post-processing instructions # set post-processing instructions
self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing") self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing")
self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing")
self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing")
def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
......
...@@ -53,8 +53,9 @@ class PostProcessing(RunEnvironment): ...@@ -53,8 +53,9 @@ class PostProcessing(RunEnvironment):
# bootstraps # bootstraps
if self.data_store.get("evaluate_bootstraps", "general.postprocessing"): if self.data_store.get("evaluate_bootstraps", "general.postprocessing"):
bootstrap_path = self.data_store.get("bootstrap_path", "general") bootstrap_path = self.data_store.get("bootstrap_path", "general.postprocessing")
BootStraps(self.test_data, bootstrap_path, 20) number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
with TimeTracking(name="split (refac_1): create_boot_straps_refac_2()"): with TimeTracking(name="split (refac_1): create_boot_straps_refac_2()"):
self.create_boot_straps_refac_2() self.create_boot_straps_refac_2()
self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores() self.bootstrap_skill_scores = self.calculate_bootstrap_skill_scores()
...@@ -77,7 +78,8 @@ class PostProcessing(RunEnvironment): ...@@ -77,7 +78,8 @@ class PostProcessing(RunEnvironment):
bootstrap_path = self.data_store.get("bootstrap_path", "general") bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general") forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20) number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
for station in bootstraps.stations: for station in bootstraps.stations:
with TimeTracking(name=station): with TimeTracking(name=station):
logging.info(station) logging.info(station)
...@@ -115,7 +117,8 @@ class PostProcessing(RunEnvironment): ...@@ -115,7 +117,8 @@ class PostProcessing(RunEnvironment):
bootstrap_path = self.data_store.get("bootstrap_path", "general") bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general") forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20) number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
for station in bootstraps.stations: for station in bootstraps.stations:
with TimeTracking(name=station): with TimeTracking(name=station):
...@@ -153,7 +156,8 @@ class PostProcessing(RunEnvironment): ...@@ -153,7 +156,8 @@ class PostProcessing(RunEnvironment):
bootstrap_path = self.data_store.get("bootstrap_path", "general") bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general") forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20) number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
# calc skill scores # calc skill scores
skill_scores = statistics.SkillScores(None) skill_scores = statistics.SkillScores(None)
...@@ -185,7 +189,8 @@ class PostProcessing(RunEnvironment): ...@@ -185,7 +189,8 @@ class PostProcessing(RunEnvironment):
bootstrap_path = self.data_store.get("bootstrap_path", "general") bootstrap_path = self.data_store.get("bootstrap_path", "general")
forecast_path = self.data_store.get("forecast_path", "general") forecast_path = self.data_store.get("forecast_path", "general")
window_lead_time = self.data_store.get("window_lead_time", "general") window_lead_time = self.data_store.get("window_lead_time", "general")
bootstraps = BootStraps(self.test_data, bootstrap_path, 20) number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing")
bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps)
skill_scores = statistics.SkillScores(None) skill_scores = statistics.SkillScores(None)
score = {} score = {}
...@@ -234,7 +239,6 @@ class PostProcessing(RunEnvironment): ...@@ -234,7 +239,6 @@ class PostProcessing(RunEnvironment):
score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"]) score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"])
return score return score
def _load_model(self): def _load_model(self):
try: try:
model = self.data_store.get("best_model", "general") model = self.data_store.get("best_model", "general")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment