Select Git revision
FrameElement.cpp
-
Ulrich Kemloh authored
test_experiment_setup.py 11.56 KiB
import argparse
import logging
import os
import pytest
from src.helpers import TimeTracking
from src.configuration.path_config import prepare_host
from src.run_modules.experiment_setup import ExperimentSetup
class TestExperimentSetup:
    """Tests for ``ExperimentSetup``: parameter storage, parser-argument handling and init behaviour."""

    @staticmethod
    def _test_experiment_folder():
        """Return the absolute path of the test experiment folder ``../data/testExperimentFolder``."""
        return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))

    @pytest.fixture
    def empty_obj(self, caplog):
        """Provide an ExperimentSetup instance created without running ``__init__``.

        ``object.__new__`` skips the constructor, so no parameters are stored by it; only ``time`` is attached
        manually. ``data_store`` is used on this instance without being assigned here, so it is presumably a
        class-level attribute shared between instances - TODO confirm in the run environment base class.
        Log capture is set to DEBUG so that records emitted by ``_set_param`` can be asserted.
        """
        obj = object.__new__(ExperimentSetup)
        obj.time = TimeTracking()
        caplog.set_level(logging.DEBUG)
        return obj

    def test_set_param_by_value(self, caplog, empty_obj):
        """A value stored without an explicit scope lands in "general" and is logged at DEBUG (level 10)."""
        empty_obj._set_param("23tester", 23)
        assert caplog.record_tuples[-1] == ('root', 10, 'set experiment attribute: 23tester(general)=23')
        assert empty_obj.data_store.get("23tester", "general") == 23

    def test_set_param_by_value_and_scope(self, empty_obj):
        """A value stored in an explicit scope can be read back from that same scope."""
        empty_obj._set_param("109tester", 109, "general.testing")
        # read from the scope that was written to (the former "general.tester" did not match the write scope)
        assert empty_obj.data_store.get("109tester", "general.testing") == 109

    def test_set_param_with_default(self, empty_obj):
        """A ``None`` value falls back to the given default; without a default ``None`` itself is stored."""
        empty_obj._set_param("NoneTester", None, "notNone", "general.testing")
        assert empty_obj.data_store.get("NoneTester", "general.testing") == "notNone"
        empty_obj._set_param("AnotherNoneTester", None)
        assert empty_obj.data_store.get("AnotherNoneTester", "general") is None

    def test_get_parser_args_from_dict(self, empty_obj):
        """A plain dict is returned unchanged."""
        res = empty_obj._get_parser_args({'test2': 2, 'test10str': "10"})
        assert res == {'test2': 2, 'test10str': "10"}

    def test_get_parser_args_from_parse_args(self, empty_obj):
        """An ``argparse.Namespace`` is converted into a dict of its attributes."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--experiment_date', type=str)
        parser_args = parser.parse_args(["--experiment_date", "TOMORROW"])
        assert empty_obj._get_parser_args(parser_args) == {"experiment_date": "TOMORROW"}

    def test_init_default(self):
        """Without arguments, ExperimentSetup stores the expected default configuration."""
        exp_setup = ExperimentSetup()
        data_store = exp_setup.data_store
        # experiment setup
        assert data_store.get("data_path", "general") == prepare_host()
        assert data_store.get("trainable", "general") is True
        assert data_store.get("create_new_model", "general") is True
        assert data_store.get("fraction_of_training", "general") == 0.8
        # set experiment name
        assert data_store.get("experiment_name", "general") == "TestExperiment_daily"
        path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment_daily"))
        assert data_store.get("experiment_path", "general") == path
        # setup for data
        default_statistics_per_var = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum',
                                      'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu',
                                      'cloudcover': 'average_values', 'pblheight': 'maximum'}
        default_stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBY052', 'DEBY032', 'DEBW022',
                            'DEBY004', 'DEBY020', 'DEBW030', 'DEBW037', 'DEBW031', 'DEBW015', 'DEBW073', 'DEBY039',
                            'DEBW038', 'DEBW081', 'DEBY075', 'DEBW040', 'DEBY053', 'DEBW059', 'DEBW027', 'DEBY072',
                            'DEBW042', 'DEBW039', 'DEBY001', 'DEBY113', 'DEBY089', 'DEBW024', 'DEBW004', 'DEBY037',
                            'DEBW056', 'DEBW029', 'DEBY068', 'DEBW010', 'DEBW026', 'DEBY002', 'DEBY079', 'DEBW084',
                            'DEBY049', 'DEBY031', 'DEBW019', 'DEBW001', 'DEBY063', 'DEBY005', 'DEBW046', 'DEBW103',
                            'DEBW052', 'DEBW034', 'DEBY088', ]
        assert data_store.get("stations", "general") == default_stations
        assert data_store.get("network", "general") == "AIRBASE"
        assert data_store.get("station_type", "general") is None
        # by default the variables are exactly the keys of the statistics mapping
        assert data_store.get("variables", "general") == list(default_statistics_per_var.keys())
        assert data_store.get("statistics_per_var", "general") == default_statistics_per_var
        assert data_store.get("start", "general") == "1997-01-01"
        assert data_store.get("end", "general") == "2017-12-31"
        assert data_store.get("window_history_size", "general") == 13
        # target
        assert data_store.get("target_var", "general") == "o3"
        assert data_store.get("target_dim", "general") == "variables"
        assert data_store.get("window_lead_time", "general") == 3
        # interpolation
        assert data_store.get("dimensions", "general") == {'new_index': ['datetime', 'Stations']}
        assert data_store.get("interpolate_dim", "general") == "datetime"
        assert data_store.get("interpolate_method", "general") == "linear"
        assert data_store.get("limit_nan_fill", "general") == 1
        # train parameters
        assert data_store.get("start", "general.train") == "1997-01-01"
        assert data_store.get("end", "general.train") == "2007-12-31"
        assert data_store.get("min_length", "general.train") == 90
        # validation parameters
        assert data_store.get("start", "general.val") == "2008-01-01"
        assert data_store.get("end", "general.val") == "2009-12-31"
        assert data_store.get("min_length", "general.val") == 90
        # test parameters
        assert data_store.get("start", "general.test") == "2010-01-01"
        assert data_store.get("end", "general.test") == "2017-12-31"
        assert data_store.get("min_length", "general.test") == 90
        # train_val parameters: span train start .. val end, min_length = train + val min_length
        assert data_store.get("start", "general.train_val") == "1997-01-01"
        assert data_store.get("end", "general.train_val") == "2009-12-31"
        assert data_store.get("min_length", "general.train_val") == 180
        # use all stations on all data sets (train, val, test)
        assert data_store.get("use_all_stations_on_all_data_sets", "general") is True

    def test_init_no_default(self):
        """Explicitly passed arguments override the defaults in the data store."""
        experiment_path = self._test_experiment_folder()
        # note: the constructor kwarg is ``fraction_of_train`` while the stored key is ``fraction_of_training``
        kwargs = dict(parser_args={"experiment_date": "TODAY"},
                      statistics_per_var={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'},
                      stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", station_type="background",
                      variables=["o3", "temp"], start="1999-01-01", end="2001-01-01", window_history_size=4,
                      target_var="relhum", target_dim="target", window_lead_time=10, dimensions="dim1",
                      interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
                      train_end="2000-01-02", val_start="2000-01-03", val_end="2000-01-04", test_start="2000-01-05",
                      test_end="2000-01-06", use_all_stations_on_all_data_sets=False, trainable=False,
                      fraction_of_train=0.5, experiment_path=experiment_path, create_new_model=True, val_min_length=20)
        exp_setup = ExperimentSetup(**kwargs)
        data_store = exp_setup.data_store
        # experiment setup
        assert data_store.get("data_path", "general") == prepare_host()
        # trainable=False was passed, but create_new_model=True is expected to force trainable on
        assert data_store.get("trainable", "general") is True
        assert data_store.get("create_new_model", "general") is True
        assert data_store.get("fraction_of_training", "general") == 0.5
        # set experiment name
        assert data_store.get("experiment_name", "general") == "TODAY_network_daily"
        path = os.path.join(experiment_path, "TODAY_network_daily")
        assert data_store.get("experiment_path", "general") == path
        # setup for data
        assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027']
        assert data_store.get("network", "general") == "INTERNET"
        assert data_store.get("station_type", "general") == "background"
        assert data_store.get("variables", "general") == ["o3", "temp"]
        assert data_store.get("statistics_per_var", "general") == {'o3': 'dma8eu', 'relhum': 'average_values',
                                                                   'temp': 'maximum'}
        assert data_store.get("start", "general") == "1999-01-01"
        assert data_store.get("end", "general") == "2001-01-01"
        assert data_store.get("window_history_size", "general") == 4
        # target
        assert data_store.get("target_var", "general") == "relhum"
        assert data_store.get("target_dim", "general") == "target"
        assert data_store.get("window_lead_time", "general") == 10
        # interpolation
        assert data_store.get("dimensions", "general") == "dim1"
        assert data_store.get("interpolate_dim", "general") == "int_dim"
        assert data_store.get("interpolate_method", "general") == "cubic"
        assert data_store.get("limit_nan_fill", "general") == 5
        # train parameters
        assert data_store.get("start", "general.train") == "2000-01-01"
        assert data_store.get("end", "general.train") == "2000-01-02"
        assert data_store.get("min_length", "general.train") == 90
        # validation parameters
        assert data_store.get("start", "general.val") == "2000-01-03"
        assert data_store.get("end", "general.val") == "2000-01-04"
        assert data_store.get("min_length", "general.val") == 20
        # test parameters
        assert data_store.get("start", "general.test") == "2000-01-05"
        assert data_store.get("end", "general.test") == "2000-01-06"
        assert data_store.get("min_length", "general.test") == 90
        # train_val parameters: min_length = train (90) + val (20)
        assert data_store.get("start", "general.train_val") == "2000-01-01"
        assert data_store.get("end", "general.train_val") == "2000-01-04"
        assert data_store.get("min_length", "general.train_val") == 110
        # use all stations on all data sets (train, val, test); same scope as in test_init_default
        assert data_store.get("use_all_stations_on_all_data_sets", "general") is False

    @pytest.mark.parametrize("trainable, create_new_model, expected_trainable", [
        (False, True, True),  # creating a new model is expected to force trainable on
        (False, False, False),
        (True, True, True),
        (True, False, True),
    ])
    def test_init_trainable_behaviour(self, trainable, create_new_model, expected_trainable):
        """``trainable`` is stored as given unless ``create_new_model`` forces it to True."""
        # keep a reference to the setup object while its data store is inspected
        exp_setup = ExperimentSetup(trainable=trainable, create_new_model=create_new_model)
        data_store = exp_setup.data_store
        assert data_store.get("trainable", "general") is expected_trainable
        assert data_store.get("create_new_model", "general") is create_new_model

    def test_compare_variables_and_statistics(self):
        """Variables without an entry in ``statistics_per_var`` raise a ValueError naming them."""
        kwargs = dict(parser_args={"experiment_date": "TODAY"},
                      statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
                      stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"],
                      experiment_path=self._test_experiment_folder())
        with pytest.raises(ValueError) as e:
            ExperimentSetup(**kwargs)
        # must run after the `with` block: statements after the raising call inside it are never executed
        assert "for the variables: {'relhum'}" in e.value.args[0]
        kwargs["variables"] = ["o3", "temp"]
        assert ExperimentSetup(**kwargs) is not None