Skip to content
Snippets Groups Projects
Commit 1faf7c2d authored by lukas leufen's avatar lukas leufen
Browse files

modified experiment setup

parent 134288fa
Branches
Tags
2 merge requests!17update to v0.4.0,!15new feat split subsets
Pipeline #26462 passed
...@@ -27,10 +27,12 @@ class ExperimentSetup(RunEnvironment): ...@@ -27,10 +27,12 @@ class ExperimentSetup(RunEnvironment):
trainable: Train new model if true, otherwise try to load existing model trainable: Train new model if true, otherwise try to load existing model
""" """
def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, variables=None, target_var="o3", def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, variables=None,
target_dim=None, dimensions=None, interpolate_dim=None, train_start=None, train_end=None, statistics_per_var=None, start=None, end=None, window_history=None, target_var="o3", target_dim=None,
val_start=None, val_end=None, test_start=None, test_end=None, use_all_stations_on_all_data_sets=True, window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
trainable=False, fraction_of_train=None, experiment_path=None): limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
test_end=None, use_all_stations_on_all_data_sets=True, trainable=False, fraction_of_train=None,
experiment_path=None):
# create run framework # create run framework
super().__init__() super().__init__()
...@@ -52,14 +54,21 @@ class ExperimentSetup(RunEnvironment): ...@@ -52,14 +54,21 @@ class ExperimentSetup(RunEnvironment):
self._set_param("stations", stations, default=DEFAULT_STATIONS) self._set_param("stations", stations, default=DEFAULT_STATIONS)
self._set_param("network", network, default="AIRBASE") self._set_param("network", network, default="AIRBASE")
self._set_param("variables", variables, default=list(self.data_store.get("var_all_dict", "general").keys())) self._set_param("variables", variables, default=list(self.data_store.get("var_all_dict", "general").keys()))
self._set_param("statistics_per_var", statistics_per_var, default=self.data_store.get("var_all_dict", "general"))
self._set_param("start", start, default="1997-01-01", scope="general")
self._set_param("end", end, default="2017-12-31", scope="general")
self._set_param("window_history", window_history, default=13)
# target # target
self._set_param("target_var", target_var, default="o3") self._set_param("target_var", target_var, default="o3")
self._set_param("target_dim", target_dim, default='variables') self._set_param("target_dim", target_dim, default='variables')
self._set_param("window_lead_time", window_lead_time, default=3)
# interpolation # interpolation
self._set_param("dimensions", dimensions, default={'new_index': ['datetime', 'Stations']}) self._set_param("dimensions", dimensions, default={'new_index': ['datetime', 'Stations']})
self._set_param("interpolate_dim", interpolate_dim, default='datetime') self._set_param("interpolate_dim", interpolate_dim, default='datetime')
self._set_param("interpolate_method", interpolate_method, default='linear')
self._set_param("limit_nan_fill", limit_nan_fill, default=1)
# train parameters # train parameters
self._set_param("start", train_start, default="1997-01-01", scope="general.train") self._set_param("start", train_start, default="1997-01-01", scope="general.train")
...@@ -69,7 +78,7 @@ class ExperimentSetup(RunEnvironment): ...@@ -69,7 +78,7 @@ class ExperimentSetup(RunEnvironment):
self._set_param("start", val_start, default="2008-01-01", scope="general.val") self._set_param("start", val_start, default="2008-01-01", scope="general.val")
self._set_param("end", val_end, default="2009-12-31", scope="general.val") self._set_param("end", val_end, default="2009-12-31", scope="general.val")
# validation parameters # test parameters
self._set_param("start", test_start, default="2010-01-01", scope="general.test") self._set_param("start", test_start, default="2010-01-01", scope="general.test")
self._set_param("end", test_end, default="2017-12-31", scope="general.test") self._set_param("end", test_end, default="2017-12-31", scope="general.test")
...@@ -83,15 +92,13 @@ class ExperimentSetup(RunEnvironment): ...@@ -83,15 +92,13 @@ class ExperimentSetup(RunEnvironment):
logging.debug(f"set experiment attribute: {param}({scope})={value}") logging.debug(f"set experiment attribute: {param}({scope})={value}")
@staticmethod @staticmethod
def _get_parser_args(args: Union[Dict, argparse.Namespace, argparse.ArgumentParser]) -> Dict: def _get_parser_args(args: Union[Dict, argparse.Namespace]) -> Dict:
""" """
Transform args to dict if given as argparse.Namespace Transform args to dict if given as argparse.Namespace
:param args: either a dictionary or an argument parser instance :param args: either a dictionary or an argument parser instance
:return: dictionary with all arguments :return: dictionary with all arguments
""" """
if isinstance(args, argparse.ArgumentParser): if isinstance(args, argparse.Namespace):
return args.parse_args().__dict__
elif isinstance(args, argparse.Namespace):
return args.__dict__ return args.__dict__
elif isinstance(args, dict): elif isinstance(args, dict):
return args return args
......
...@@ -45,15 +45,18 @@ class TestExperimentSetup: ...@@ -45,15 +45,18 @@ class TestExperimentSetup:
def test_init_default(self): def test_init_default(self):
exp_setup = ExperimentSetup() exp_setup = ExperimentSetup()
data_store = exp_setup.data_store data_store = exp_setup.data_store
# experiment setup
assert data_store.get("data_path", "general") == prepare_host() assert data_store.get("data_path", "general") == prepare_host()
assert data_store.get("trainable", "general") is False assert data_store.get("trainable", "general") is False
assert data_store.get("fraction_of_train", "general") == 0.8 assert data_store.get("fraction_of_train", "general") == 0.8
# set experiment name
assert data_store.get("experiment_name", "general") == "TestExperiment" assert data_store.get("experiment_name", "general") == "TestExperiment"
path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment")) path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment"))
assert data_store.get("experiment_path", "general") == path assert data_store.get("experiment_path", "general") == path
default_var_all_dict = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', default_var_all_dict = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values',
'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', 'cloudcover': 'average_values',
'pblheight': 'maximum'} 'pblheight': 'maximum'}
# setup for data
assert data_store.get("var_all_dict", "general") == default_var_all_dict assert data_store.get("var_all_dict", "general") == default_var_all_dict
default_stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022', default_stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022',
'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039', 'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039',
...@@ -65,50 +68,80 @@ class TestExperimentSetup: ...@@ -65,50 +68,80 @@ class TestExperimentSetup:
assert data_store.get("stations", "general") == default_stations assert data_store.get("stations", "general") == default_stations
assert data_store.get("network", "general") == "AIRBASE" assert data_store.get("network", "general") == "AIRBASE"
assert data_store.get("variables", "general") == list(default_var_all_dict.keys()) assert data_store.get("variables", "general") == list(default_var_all_dict.keys())
assert data_store.get("statistics_per_var", "general") == default_var_all_dict
assert data_store.get("start", "general") == "1997-01-01"
assert data_store.get("end", "general") == "2017-12-31"
assert data_store.get("window_history", "general") == 13
# target
assert data_store.get("target_var", "general") == "o3" assert data_store.get("target_var", "general") == "o3"
assert data_store.get("target_dim", "general") == "variables" assert data_store.get("target_dim", "general") == "variables"
assert data_store.get("window_lead_time", "general") == 3
# interpolation
assert data_store.get("dimensions", "general") == {'new_index': ['datetime', 'Stations']} assert data_store.get("dimensions", "general") == {'new_index': ['datetime', 'Stations']}
assert data_store.get("interpolate_dim", "general") == "datetime" assert data_store.get("interpolate_dim", "general") == "datetime"
with pytest.raises(NameNotFoundInScope): assert data_store.get("interpolate_method", "general") == "linear"
data_store.get("start", "general") assert data_store.get("limit_nan_fill", "general") == 1
with pytest.raises(NameNotFoundInScope): # train parameters
data_store.get("end", "general")
assert data_store.get("start", "general.train") == "1997-01-01" assert data_store.get("start", "general.train") == "1997-01-01"
assert data_store.get("end", "general.train") == "2007-12-31" assert data_store.get("end", "general.train") == "2007-12-31"
# validation parameters
assert data_store.get("start", "general.val") == "2008-01-01" assert data_store.get("start", "general.val") == "2008-01-01"
assert data_store.get("end", "general.val") == "2009-12-31" assert data_store.get("end", "general.val") == "2009-12-31"
# test parameters
assert data_store.get("start", "general.test") == "2010-01-01" assert data_store.get("start", "general.test") == "2010-01-01"
assert data_store.get("end", "general.test") == "2017-12-31" assert data_store.get("end", "general.test") == "2017-12-31"
# use all stations on all data sets (train, val, test)
assert data_store.get("use_all_stations_on_all_data_sets", "general") is True
def test_init_no_default(self): def test_init_no_default(self):
experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder")) experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
kwargs = dict(parser_args={"experiment_date": "TODAY"}, kwargs = dict(parser_args={"experiment_date": "TODAY"},
var_all_dict={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'}, var_all_dict={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'},
stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", variables=["o3", "temp"], stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", variables=["o3", "temp"],
target_var="temp", target_dim="target", dimensions="dim1", interpolate_dim="int_dim", statistics_per_var=None, start="1999-01-01", end="2001-01-01", window_history=4,
train_start="2000-01-01", train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1",
test_start="2000-01-05", test_end="2000-01-06", use_all_stations_on_all_data_sets=False, interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
trainable=True, fraction_of_train=0.5, experiment_path=experiment_path) train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05",
test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=True,
fraction_of_train=0.5, experiment_path=experiment_path)
exp_setup = ExperimentSetup(**kwargs) exp_setup = ExperimentSetup(**kwargs)
data_store = exp_setup.data_store data_store = exp_setup.data_store
# experiment setup
assert data_store.get("data_path", "general") == prepare_host() assert data_store.get("data_path", "general") == prepare_host()
assert data_store.get("trainable", "general") is True assert data_store.get("trainable", "general") is True
assert data_store.get("fraction_of_train", "general") == 0.5 assert data_store.get("fraction_of_train", "general") == 0.5
# set experiment name
assert data_store.get("experiment_name", "general") == "TODAY_network/" assert data_store.get("experiment_name", "general") == "TODAY_network/"
path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder")) path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
assert data_store.get("experiment_path", "general") == path assert data_store.get("experiment_path", "general") == path
# setup for data
assert data_store.get("var_all_dict", "general") == {'o3': 'dma8eu', 'relhum': 'average_values', assert data_store.get("var_all_dict", "general") == {'o3': 'dma8eu', 'relhum': 'average_values',
'temp': 'maximum'} 'temp': 'maximum'}
assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027'] assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027']
assert data_store.get("network", "general") == "INTERNET" assert data_store.get("network", "general") == "INTERNET"
assert data_store.get("variables", "general") == ["o3", "temp"] assert data_store.get("variables", "general") == ["o3", "temp"]
assert data_store.get("statistics_per_var", "general") == {'o3': 'dma8eu', 'relhum': 'average_values',
'temp': 'maximum'}
assert data_store.get("start", "general") == "1999-01-01"
assert data_store.get("end", "general") == "2001-01-01"
assert data_store.get("window_history", "general") == 4
# target
assert data_store.get("target_var", "general") == "temp" assert data_store.get("target_var", "general") == "temp"
assert data_store.get("target_dim", "general") == "target" assert data_store.get("target_dim", "general") == "target"
assert data_store.get("window_lead_time", "general") == 10
# interpolation
assert data_store.get("dimensions", "general") == "dim1" assert data_store.get("dimensions", "general") == "dim1"
assert data_store.get("interpolate_dim", "general") == "int_dim" assert data_store.get("interpolate_dim", "general") == "int_dim"
assert data_store.get("interpolate_method", "general") == "cubic"
assert data_store.get("limit_nan_fill", "general") == 5
# train parameters
assert data_store.get("start", "general.train") == "2000-01-01" assert data_store.get("start", "general.train") == "2000-01-01"
assert data_store.get("end", "general.train") == "2000-01-02" assert data_store.get("end", "general.train") == "2000-01-02"
# validation parameters
assert data_store.get("start", "general.val") == "2000-01-03" assert data_store.get("start", "general.val") == "2000-01-03"
assert data_store.get("end", "general.val") == "2000-01-04" assert data_store.get("end", "general.val") == "2000-01-04"
# test parameters
assert data_store.get("start", "general.test") == "2000-01-05" assert data_store.get("start", "general.test") == "2000-01-05"
assert data_store.get("end", "general.test") == "2000-01-06" assert data_store.get("end", "general.test") == "2000-01-06"
# use all stations on all data sets (train, val, test)
assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment