diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index 9c5d68688462ed33e91151dab685af5811cc3120..da26af924260629f8ce7b3062b0e2d3b91d8ebee 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -27,7 +27,7 @@ class ExperimentSetup(RunEnvironment): trainable: Train new model if true, otherwise try to load existing model """ - def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, station_type=None, variables=None, + def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None, statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", target_dim=None, window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None, limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, @@ -68,12 +68,11 @@ class ExperimentSetup(RunEnvironment): helpers.check_path_and_create(self.data_store.get("forecast_path", "general")) # setup for data - self._set_param("var_all_dict", var_all_dict, default=DEFAULT_VAR_ALL_DICT) self._set_param("stations", stations, default=DEFAULT_STATIONS) self._set_param("network", network, default="AIRBASE") self._set_param("station_type", station_type, default=None) - self._set_param("variables", variables, default=list(self.data_store.get("var_all_dict", "general").keys())) - self._set_param("statistics_per_var", statistics_per_var, default=self.data_store.get("var_all_dict", "general")) + self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT) + self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var", "general").keys())) self._compare_variables_and_statistics() self._set_param("start", start, default="1997-01-01", scope="general") self._set_param("end", end, default="2017-12-31", scope="general") @@ -83,6 +82,7 @@ class ExperimentSetup(RunEnvironment): # target self._set_param("target_var", target_var, default="o3") + self._check_target_var() self._set_param("target_dim", target_dim, default='variables') self._set_param("window_lead_time", window_lead_time, default=3) @@ -132,16 +132,27 @@ class ExperimentSetup(RunEnvironment): return {} def _compare_variables_and_statistics(self): - logging.debug("check if all variables are included in statistics_per_var") - var = self.data_store.get("variables", "general") stat = self.data_store.get("statistics_per_var", "general") + var = self.data_store.get("variables", "general") if not set(var).issubset(stat.keys()): missing = set(var).difference(stat.keys()) raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested " f"variables are part of statistics_per_var. Please add also information on the missing " f"statistics for the variables: {missing}") + def _check_target_var(self): + target_var = helpers.to_list(self.data_store.get("target_var", "general")) + stat = self.data_store.get("statistics_per_var", "general") + var = self.data_store.get("variables", "general") + if not set(target_var).issubset(stat.keys()): + raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.") + unused_vars = set(stat.keys()).difference(set(var).union(target_var)) + if len(unused_vars) > 0: + logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}") + stat_new = helpers.dict_pop(stat, list(unused_vars)) + self._set_param("statistics_per_var", stat_new) + if __name__ == "__main__":