diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 56c22a81e48421438816855770b7477e84e3a8d8..53ef2c61af61f6f2005f6ee7fde81eb1cfefc340 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -20,6 +20,8 @@ DEFAULT_VAR_ALL_DICT = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'max 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', 'pblheight': 'maximum'} DEFAULT_TRANSFORMATION = {"scope": "data", "method": "standardise", "mean": "estimate"} +DEFAULT_PLOT_LIST = ["PlotMonthlySummary", "PlotStationMap", "PlotClimatologicalSkillScore", "PlotTimeSeries", + "PlotCompetitiveSkillScore", "PlotBootstrapSkillScore", "plot_conditional_quantiles"] class ExperimentSetup(RunEnvironment): @@ -34,7 +36,8 @@ class ExperimentSetup(RunEnvironment): limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", - create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None): + create_new_model=None, bootstrap_path=None, permute_data_on_training=None, transformation=None, + evaluate_bootstraps=True, plot_list=None): # create run framework super().__init__() @@ -115,6 +118,10 @@ class ExperimentSetup(RunEnvironment): # use all stations on all data sets (train, val, test) self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True) + # set post-processing instructions + self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing") + self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") + def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: if value is None and default is not None: value = default diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index 0a61ee4f07d0c6eccf698aa16d3de9d7275e75f6..c7ac3f10799d4329e8744d6ea69286c57119645a 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -50,7 +50,8 @@ class PostProcessing(RunEnvironment): self.make_prediction() logging.info("take a look on the next reported time measure. If this increases a lot, one should think to " "skip make_prediction() whenever it is possible to save time.") - self.bootstrap_skill_scores = self.create_boot_straps() + if self.data_store.get("evaluate_bootstraps", "general.postprocessing"): + self.bootstrap_skill_scores = self.create_boot_straps() self.skill_scores = self.calculate_skill_scores() self.plot() @@ -124,19 +125,29 @@ class PostProcessing(RunEnvironment): logging.debug("Run plotting routines...") path = self.data_store.get("forecast_path", "general") - plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="obs", - forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path) - plot_conditional_quantiles(self.test_data.stations, pred_name="obs", ref_name="CNN", - forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path) - PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path) - PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", self.target_var, - plot_folder=self.plot_path) - PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN") - PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, - extra_name_tag="all_terms_", model_setup="CNN") - PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN") - PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN") - PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path, sampling=self._sampling) + plot_list = self.data_store.get("plot_list", "general.postprocessing") + + if "plot_conditional_quantiles" in plot_list: + plot_conditional_quantiles(self.test_data.stations, pred_name="CNN", ref_name="obs", + forecast_path=path, plot_name_affix="cali-ref", plot_folder=self.plot_path) + plot_conditional_quantiles(self.test_data.stations, pred_name="obs", ref_name="CNN", + forecast_path=path, plot_name_affix="like-bas", plot_folder=self.plot_path) + if "PlotStationMap" in plot_list: + PlotStationMap(generators={'b': self.test_data}, plot_folder=self.plot_path) + if "PlotMonthlySummary" in plot_list: + PlotMonthlySummary(self.test_data.stations, path, r"forecasts_%s_test.nc", self.target_var, + plot_folder=self.plot_path) + if "PlotClimatologicalSkillScore" in plot_list: + PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, model_setup="CNN") + PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False, + extra_name_tag="all_terms_", model_setup="CNN") + if "PlotCompetitiveSkillScore" in plot_list: + PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN") + if self.bootstrap_skill_scores is not None and "PlotBootstrapSkillScore" in plot_list: + PlotBootstrapSkillScore(self.bootstrap_skill_scores, plot_folder=self.plot_path, model_setup="CNN") + if "PlotTimeSeries" in plot_list: + PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path, + sampling=self._sampling) def calculate_test_score(self): test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),