Skip to content
Snippets Groups Projects
Commit 1235b87c authored by lukas leufen's avatar lukas leufen
Browse files

experiment setup wasn't added to last commit

parent d0b5ade5
No related branches found
No related tags found
2 merge requests!59Develop,!54Lukas issue061 refac seperate input target vars
Pipeline #30910 passed
......@@ -27,7 +27,7 @@ class ExperimentSetup(RunEnvironment):
trainable: Train new model if true, otherwise try to load existing model
"""
def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, station_type=None, variables=None,
def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None,
statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", target_dim=None,
window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
......@@ -68,12 +68,11 @@ class ExperimentSetup(RunEnvironment):
helpers.check_path_and_create(self.data_store.get("forecast_path", "general"))
# setup for data
self._set_param("var_all_dict", var_all_dict, default=DEFAULT_VAR_ALL_DICT)
self._set_param("stations", stations, default=DEFAULT_STATIONS)
self._set_param("network", network, default="AIRBASE")
self._set_param("station_type", station_type, default=None)
self._set_param("variables", variables, default=list(self.data_store.get("var_all_dict", "general").keys()))
self._set_param("statistics_per_var", statistics_per_var, default=self.data_store.get("var_all_dict", "general"))
self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var", "general").keys()))
self._compare_variables_and_statistics()
self._set_param("start", start, default="1997-01-01", scope="general")
self._set_param("end", end, default="2017-12-31", scope="general")
......@@ -83,6 +82,7 @@ class ExperimentSetup(RunEnvironment):
# target
self._set_param("target_var", target_var, default="o3")
self._check_target_var()
self._set_param("target_dim", target_dim, default='variables')
self._set_param("window_lead_time", window_lead_time, default=3)
......@@ -132,16 +132,27 @@ class ExperimentSetup(RunEnvironment):
return {}
def _compare_variables_and_statistics(self):
logging.debug("check if all variables are included in statistics_per_var")
var = self.data_store.get("variables", "general")
stat = self.data_store.get("statistics_per_var", "general")
var = self.data_store.get("variables", "general")
if not set(var).issubset(stat.keys()):
missing = set(var).difference(stat.keys())
raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested "
f"variables are part of statistics_per_var. Please add also information on the missing "
f"statistics for the variables: {missing}")
def _check_target_var(self):
target_var = helpers.to_list(self.data_store.get("target_var", "general"))
stat = self.data_store.get("statistics_per_var", "general")
var = self.data_store.get("variables", "general")
if not set(target_var).issubset(stat.keys()):
raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.")
unused_vars = set(stat.keys()).difference(set(var).union(target_var))
if len(unused_vars) > 0:
logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}")
stat_new = helpers.dict_pop(stat, list(unused_vars))
self._set_param("statistics_per_var", stat_new)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment