import argparse
import logging
import os

import pytest
import mock

from mlair.helpers import TimeTracking, to_list
from mlair.configuration.path_config import prepare_host
from mlair.run_modules.experiment_setup import ExperimentSetup


class TestExperimentSetup:
    """Tests for ``ExperimentSetup``: single-parameter storage and full default / custom initialisation."""

    @pytest.fixture
    def empty_obj(self, caplog):
        """Return an ``ExperimentSetup`` instance whose own ``__init__`` was skipped.

        Only the parent class initialiser is run, so the data store can be used without running the
        experiment configuration. Log capture is set to DEBUG so tests can inspect the messages
        emitted by ``_set_param``.
        """
        obj = object.__new__(ExperimentSetup)
        super(ExperimentSetup, obj).__init__()
        caplog.set_level(logging.DEBUG)
        return obj

    def test_set_param_by_value(self, caplog, empty_obj):
        """A value without explicit scope is stored in the "general" scope and logged at DEBUG level."""
        empty_obj._set_param("23tester", 23)
        assert caplog.record_tuples[-1] == ('root', 10, 'set experiment attribute: 23tester(general)=23')
        assert empty_obj.data_store.get("23tester", "general") == 23

    def test_set_param_by_value_and_scope(self, caplog, empty_obj):
        """A value is stored in (and logged for) the explicitly requested scope."""
        # The third positional argument of _set_param is the default value (see
        # test_set_param_with_default); the scope is the fourth. Passing the scope third would
        # silently store the value in "general" and the test would not exercise scoping at all.
        empty_obj._set_param("109tester", 109, None, "general.testing")
        assert caplog.record_tuples[-1] == ('root', 10, 'set experiment attribute: 109tester(general.testing)=109')
        assert empty_obj.data_store.get("109tester", "general.testing") == 109

    def test_set_param_with_default(self, caplog, empty_obj):
        """A None value is replaced by the default; without a default, None itself is stored."""
        empty_obj._set_param("NoneTester", None, "notNone", "general.testing")
        assert empty_obj.data_store.get("NoneTester", "general.testing") == "notNone"
        empty_obj._set_param("AnotherNoneTester", None)
        assert empty_obj.data_store.get("AnotherNoneTester", "general") is None

    def test_set_param_with_apply(self, caplog, empty_obj):
        """``apply`` is called on the final value (after default substitution); ``apply=None`` is a no-op."""
        empty_obj._set_param("NoneTester", None, default="notNone", apply=None)
        assert empty_obj.data_store.get("NoneTester") == "notNone"
        empty_obj._set_param("NoneTester", None, default="notNone", apply=to_list)
        assert empty_obj.data_store.get("NoneTester") == ["notNone"]
        empty_obj._set_param("NoneTester", None, apply=to_list)
        assert empty_obj.data_store.get("NoneTester") == [None]
        empty_obj._set_param("NoneTester", 2.3, apply=int)
        assert empty_obj.data_store.get("NoneTester") == 2

    def test_init_default(self):
        """Check every parameter written by ``ExperimentSetup()`` when no arguments are given."""
        exp_setup = ExperimentSetup()
        data_store = exp_setup.data_store
        # experiment setup
        assert data_store.get("data_path", "general") == prepare_host()
        assert data_store.get("train_model", "general") is True
        assert data_store.get("create_new_model", "general") is True
        assert data_store.get("fraction_of_training", "general") == 0.8
        # set experiment name
        assert data_store.get("experiment_name", "general") == "TestExperiment_daily"
        path = os.path.abspath(os.path.join(os.getcwd(), "TestExperiment_daily"))
        assert data_store.get("experiment_path", "general") == path
        default_statistics_per_var = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum',
                                      'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu',
                                      'no2': 'dma8eu', 'cloudcover': 'average_values', 'pblheight': 'maximum'}
        # setup for data
        default_stations = ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']
        assert data_store.get("stations", "general") == default_stations
        assert data_store.get("variables", "general") == list(default_statistics_per_var.keys())
        assert data_store.get("statistics_per_var", "general") == default_statistics_per_var
        assert data_store.get("start", "general") == "1997-01-01"
        assert data_store.get("end", "general") == "2017-12-31"
        assert data_store.get("window_history_size", "general") == 13
        # target
        assert data_store.get("target_var", "general") == "o3"
        assert data_store.get("target_dim", "general") == "variables"
        assert data_store.get("window_lead_time", "general") == 3
        # interpolation
        assert data_store.get("dimensions", "general") == {'new_index': ['datetime', 'Stations']}
        assert data_store.get("time_dim", "general") == "datetime"
        assert data_store.get("interpolation_method", "general") == "linear"
        assert data_store.get("interpolation_limit", "general") == 1
        # train parameters
        assert data_store.get("start", "general.train") == "1997-01-01"
        assert data_store.get("end", "general.train") == "2007-12-31"
        assert data_store.get("min_length", "general.train") == 90
        # validation parameters
        assert data_store.get("start", "general.val") == "2008-01-01"
        assert data_store.get("end", "general.val") == "2009-12-31"
        assert data_store.get("min_length", "general.val") == 90
        # test parameters
        assert data_store.get("start", "general.test") == "2010-01-01"
        assert data_store.get("end", "general.test") == "2017-12-31"
        assert data_store.get("min_length", "general.test") == 90
        # train_val parameters: spans train + val; min_length is the sum of both (90 + 90)
        assert data_store.get("start", "general.train_val") == "1997-01-01"
        assert data_store.get("end", "general.train_val") == "2009-12-31"
        assert data_store.get("min_length", "general.train_val") == 180
        # use all stations on all data sets (train, val, test)
        assert data_store.get("use_all_stations_on_all_data_sets", "general") is True

    def test_init_no_default(self):
        """Check that explicitly passed arguments override the defaults in the data store."""
        experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
        kwargs = dict(experiment_date="TODAY",
                      statistics_per_var={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'},
                      stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", station_type="background",
                      variables=["o3", "temp"], start="1999-01-01", end="2001-01-01", window_history_size=4,
                      target_var="relhum", target_dim="target", window_lead_time=10, dimensions="dim1",
                      time_dim="int_dim", interpolation_method="cubic", interpolation_limit=5,
                      train_start="2000-01-01", train_end="2000-01-02", val_start="2000-01-03",
                      val_end="2000-01-04", test_start="2000-01-05", test_end="2000-01-06",
                      use_all_stations_on_all_data_sets=False, trainable=False, fraction_of_train=0.5,
                      experiment_path=experiment_path, create_new_model=True, val_min_length=20)
        exp_setup = ExperimentSetup(**kwargs)
        data_store = exp_setup.data_store
        # experiment setup
        assert data_store.get("data_path", "general") == prepare_host()
        assert data_store.get("train_model", "general") is True
        assert data_store.get("create_new_model", "general") is True
        assert data_store.get("fraction_of_training", "general") == 0.5
        # set experiment name
        assert data_store.get("experiment_name", "general") == "TODAY_network_daily"
        path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder",
                                            "TODAY_network_daily"))
        assert data_store.get("experiment_path", "general") == path
        # setup for data
        assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027']
        assert data_store.get("network", "general") == "INTERNET"
        assert data_store.get("station_type", "general") == "background"
        assert data_store.get("variables", "general") == ["o3", "temp"]
        assert data_store.get("statistics_per_var", "general") == {'o3': 'dma8eu', 'relhum': 'average_values',
                                                                   'temp': 'maximum'}
        assert data_store.get("start", "general") == "1999-01-01"
        assert data_store.get("end", "general") == "2001-01-01"
        assert data_store.get("window_history_size", "general") == 4
        # target
        assert data_store.get("target_var", "general") == "relhum"
        assert data_store.get("target_dim", "general") == "target"
        assert data_store.get("window_lead_time", "general") == 10
        # interpolation
        assert data_store.get("dimensions", "general") == "dim1"
        assert data_store.get("time_dim", "general") == "int_dim"
        assert data_store.get("interpolation_method", "general") == "cubic"
        assert data_store.get("interpolation_limit", "general") == 5
        # train parameters
        assert data_store.get("start", "general.train") == "2000-01-01"
        assert data_store.get("end", "general.train") == "2000-01-02"
        assert data_store.get("min_length", "general.train") == 90
        # validation parameters
        assert data_store.get("start", "general.val") == "2000-01-03"
        assert data_store.get("end", "general.val") == "2000-01-04"
        assert data_store.get("min_length", "general.val") == 20
        # test parameters
        assert data_store.get("start", "general.test") == "2000-01-05"
        assert data_store.get("end", "general.test") == "2000-01-06"
        assert data_store.get("min_length", "general.test") == 90
        # train_val parameters: spans train + val; min_length is the sum of both (90 + 20)
        assert data_store.get("start", "general.train_val") == "2000-01-01"
        assert data_store.get("end", "general.train_val") == "2000-01-04"
        assert data_store.get("min_length", "general.train_val") == 110
        # use all stations on all data sets (train, val, test)
        assert data_store.get("use_all_stations_on_all_data_sets", "general.test") is False

    def test_init_train_model_behaviour(self):
        """``create_new_model=True`` forces ``train_model`` to True; otherwise ``train_model`` is kept."""
        exp_setup = ExperimentSetup(train_model=False, create_new_model=True)
        data_store = exp_setup.data_store
        assert data_store.get("train_model", "general") is True
        assert data_store.get("create_new_model", "general") is True
        exp_setup = ExperimentSetup(train_model=False, create_new_model=False)
        data_store = exp_setup.data_store
        assert data_store.get("train_model", "general") is False
        assert data_store.get("create_new_model", "general") is False
        exp_setup = ExperimentSetup(train_model=True, create_new_model=True)
        data_store = exp_setup.data_store
        assert data_store.get("train_model", "general") is True
        assert data_store.get("create_new_model", "general") is True
        exp_setup = ExperimentSetup(train_model=True, create_new_model=False)
        data_store = exp_setup.data_store
        assert data_store.get("train_model", "general") is True
        assert data_store.get("create_new_model", "general") is False

    def test_compare_variables_and_statistics(self):
        """Variables without an entry in ``statistics_per_var`` raise ValueError naming those variables."""
        experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
        kwargs = dict(experiment_date="TODAY",
                      statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
                      stations=['DEBY053', 'DEBW059', 'DEBW027'], variables=["o3", "relhum"],
                      experiment_path=experiment_path)
        with pytest.raises(ValueError) as e:
            ExperimentSetup(**kwargs)
        # Must be outside the ``raises`` block: anything after the raising call inside it never runs.
        assert "for the variables: {'relhum'}" in e.value.args[0]
        kwargs["variables"] = ["o3", "temp"]
        assert ExperimentSetup(**kwargs) is not None

    @mock.patch("sys.gettrace", return_value=None)
    def test_multiprocessing_no_debug(self, mock_gettrace):
        """Without a trace function, ``use_multiprocessing`` is used and ``use_multiprocessing_on_debug`` ignored."""
        # sys.gettrace is patched to None so the result does not depend on whether a tracer
        # (debugger, coverage) is active while the test suite runs; the debug counterpart below
        # patches it to a non-None value.
        # no debug mode, parallel
        exp_setup = ExperimentSetup(use_multiprocessing_on_debug=False)
        assert exp_setup.data_store.get("use_multiprocessing") is True
        # no debug mode, serial
        exp_setup = ExperimentSetup(use_multiprocessing=False, use_multiprocessing_on_debug=True)
        assert exp_setup.data_store.get("use_multiprocessing") is False

    @mock.patch("sys.gettrace", return_value="dummy_not_null")
    def test_multiprocessing_debug(self, mock_gettrace):
        """With a trace function active, ``use_multiprocessing_on_debug`` decides and ``use_multiprocessing`` is ignored."""
        # debug mode, parallel
        exp_setup = ExperimentSetup(use_multiprocessing=False, use_multiprocessing_on_debug=True)
        assert exp_setup.data_store.get("use_multiprocessing") is True
        # debug mode, serial
        exp_setup = ExperimentSetup(use_multiprocessing=True)
        assert exp_setup.data_store.get("use_multiprocessing") is False