Skip to content
Snippets Groups Projects
Commit 1235b87c authored by lukas leufen's avatar lukas leufen
Browse files

experiment setup wasn't added to last commit

parent d0b5ade5
Branches
Tags
2 merge requests!59Develop,!54Lukas issue061 refac seperate input target vars
Pipeline #30910 passed
......@@ -27,7 +27,7 @@ class ExperimentSetup(RunEnvironment):
trainable: Train new model if true, otherwise try to load existing model
"""
def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, station_type=None, variables=None,
def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None,
statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", target_dim=None,
window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
......@@ -68,12 +68,11 @@ class ExperimentSetup(RunEnvironment):
helpers.check_path_and_create(self.data_store.get("forecast_path", "general"))
# setup for data
self._set_param("var_all_dict", var_all_dict, default=DEFAULT_VAR_ALL_DICT)
self._set_param("stations", stations, default=DEFAULT_STATIONS)
self._set_param("network", network, default="AIRBASE")
self._set_param("station_type", station_type, default=None)
self._set_param("variables", variables, default=list(self.data_store.get("var_all_dict", "general").keys()))
self._set_param("statistics_per_var", statistics_per_var, default=self.data_store.get("var_all_dict", "general"))
self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var", "general").keys()))
self._compare_variables_and_statistics()
self._set_param("start", start, default="1997-01-01", scope="general")
self._set_param("end", end, default="2017-12-31", scope="general")
......@@ -83,6 +82,7 @@ class ExperimentSetup(RunEnvironment):
# target
self._set_param("target_var", target_var, default="o3")
self._check_target_var()
self._set_param("target_dim", target_dim, default='variables')
self._set_param("window_lead_time", window_lead_time, default=3)
......@@ -132,16 +132,27 @@ class ExperimentSetup(RunEnvironment):
return {}
def _compare_variables_and_statistics(self):
logging.debug("check if all variables are included in statistics_per_var")
var = self.data_store.get("variables", "general")
stat = self.data_store.get("statistics_per_var", "general")
var = self.data_store.get("variables", "general")
if not set(var).issubset(stat.keys()):
missing = set(var).difference(stat.keys())
raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested "
f"variables are part of statistics_per_var. Please add also information on the missing "
f"statistics for the variables: {missing}")
def _check_target_var(self):
target_var = helpers.to_list(self.data_store.get("target_var", "general"))
stat = self.data_store.get("statistics_per_var", "general")
var = self.data_store.get("variables", "general")
if not set(target_var).issubset(stat.keys()):
raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")
unused_vars = set(stat.keys()).difference(set(var).union(target_var))
if len(unused_vars) > 0:
logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}")
stat_new = helpers.dict_pop(stat, list(unused_vars))
self._set_param("statistics_per_var", stat_new)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment