"""Tests for the PreProcessing run module (src.modules.pre_processing)."""

import logging

import pytest

from src.helpers import PyTestRegex
from src.modules.experiment_setup import ExperimentSetup
from src.modules.pre_processing import PreProcessing, DEFAULT_ARGS_LIST, DEFAULT_KWARGS_LIST
from src.data_generator import DataGenerator
from src.datastore import NameNotFoundInScope
from src.modules.run_environment import RunEnvironment


class TestPreProcessing:
    """Unit tests for ``PreProcessing`` and its helper methods."""

    @pytest.fixture
    def obj_no_init(self):
        """Return a bare ``PreProcessing`` instance created without running any initializer."""
        return object.__new__(PreProcessing)

    @pytest.fixture
    def obj_super_init(self):
        """Yield a ``PreProcessing`` with only the parent initializer run and some dummy data-store entries.

        NAME1 is stored in both "general" (1) and "general.sub" (10) to exercise scope resolution; NAME4 exists only
        in "general.sub.sub".
        """
        obj = object.__new__(PreProcessing)
        # Skip PreProcessing.__init__ (which starts processing) and run only the parent initializer.
        super(PreProcessing, obj).__init__()
        obj.data_store.put("NAME1", 1, "general")
        obj.data_store.put("NAME2", 2, "general")
        obj.data_store.put("NAME3", 3, "general")
        obj.data_store.put("NAME1", 10, "general.sub")
        obj.data_store.put("NAME4", 4, "general.sub.sub")
        yield obj
        # Teardown: finalize a run environment so state does not leak into other tests
        # (presumably resets the shared data store -- confirm in RunEnvironment.__del__).
        RunEnvironment().__del__()

    @pytest.fixture
    def obj_with_exp_setup(self):
        """Yield a ``PreProcessing`` (parent initializer only) after an ExperimentSetup with six stations."""
        ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
                        var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background")
        pre = object.__new__(PreProcessing)
        # Skip PreProcessing.__init__ so each test can call the individual steps itself.
        super(PreProcessing, pre).__init__()
        yield pre
        RunEnvironment().__del__()

    def test_init(self, caplog):
        """Full construction logs start, station check and the station-check summary."""
        try:
            ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                            var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
            # Raise the level only now so that only PreProcessing's INFO records are inspected below.
            caplog.set_level(logging.INFO)
            PreProcessing()
            assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
            assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
            assert caplog.record_tuples[-2] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)'))
        finally:
            # Always clean up, even when an assertion fails, so state does not leak into later tests.
            RunEnvironment().__del__()

    def test_run(self, obj_with_exp_setup):
        """``_run`` returns None and registers a generator in the train, val and test scopes."""
        assert obj_with_exp_setup.data_store.search_name("generator") == []
        assert obj_with_exp_setup._run() is None
        assert obj_with_exp_setup.data_store.search_name("generator") == sorted(["general.train", "general.val",
                                                                                  "general.test"])

    def test_split_train_val_test(self, obj_with_exp_setup):
        """Splitting populates generator, start, end and stations for each subset scope."""
        assert obj_with_exp_setup.data_store.search_name("generator") == []
        obj_with_exp_setup.split_train_val_test()
        data_store = obj_with_exp_setup.data_store
        assert data_store.search_scope("general.train") == sorted(["generator", "start", "end", "stations"])
        assert data_store.search_name("generator") == sorted(["general.train", "general.val", "general.test"])

    def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup):
        """With use_all_stations_on_all_data_sets=False only the sliced stations are used."""
        caplog.set_level(logging.DEBUG)
        obj_with_exp_setup.data_store.put("use_all_stations_on_all_data_sets", False, "general.awesome")
        obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
        assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']")
        data_store = obj_with_exp_setup.data_store
        assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
        # The generator must be stored in the subset scope only, not in the parent scope.
        with pytest.raises(NameNotFoundInScope):
            data_store.get("generator", "general")
        assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBY081"]

    def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
        """Without the flag set to False, all stations are used regardless of the slice."""
        caplog.set_level(logging.DEBUG)
        obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
        assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', "
                                                       "'DEBW076', 'DEBW087', 'DEBW001']")
        data_store = obj_with_exp_setup.data_store
        assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
        with pytest.raises(NameNotFoundInScope):
            data_store.get("generator", "general")
        # All six stations are logged, but only the five valid ones end up stored (DEBW001 is rejected).
        assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087']

    def test_check_valid_stations(self, caplog, obj_with_exp_setup):
        """Invalid stations are dropped and a summary with the valid/total count is logged."""
        pre = obj_with_exp_setup
        caplog.set_level(logging.INFO)
        args = pre._create_args_dict(DEFAULT_ARGS_LIST)
        kwargs = pre._create_args_dict(DEFAULT_KWARGS_LIST)
        stations = pre.data_store.get("stations", "general")
        valid_stations = pre.check_valid_stations(args, kwargs, stations)
        assert len(valid_stations) < len(stations)
        # The last configured station (DEBW001) is expected to be the only invalid one.
        assert valid_stations == stations[:-1]
        assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
        # Periods are escaped so the literal punctuation of the log message is checked, not "any character".
        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 6 station\(s\)\. '
                                                                    r'Found 5/6 valid stations\.'))

    def test_split_set_indices(self, obj_no_init):
        """15 items with fraction 0.9 split into 10 train, 3 val and 2 test indices."""
        dummy_list = list(range(0, 15))
        train, val, test = obj_no_init.split_set_indices(len(dummy_list), 0.9)
        assert dummy_list[train] == list(range(0, 10))
        assert dummy_list[val] == list(range(10, 13))
        assert dummy_list[test] == list(range(13, 15))

    def test_create_args_dict_default_scope(self, obj_super_init):
        """Names are resolved from the default scope."""
        assert obj_super_init._create_args_dict(["NAME1", "NAME2"]) == {"NAME1": 1, "NAME2": 2}

    def test_create_args_dict_given_scope(self, obj_super_init):
        """A sub-scope value shadows the parent value; missing names fall back to the parent scope."""
        assert obj_super_init._create_args_dict(["NAME1", "NAME2"], scope="general.sub") == {"NAME1": 10, "NAME2": 2}

    def test_create_args_dict_missing_entry(self, obj_super_init):
        """Names not visible from the scope are silently omitted from the result."""
        # NAME5 does not exist at all.
        assert obj_super_init._create_args_dict(["NAME5", "NAME2"]) == {"NAME2": 2}
        # NAME4 exists only in the deeper "general.sub.sub" scope, which is not visible from the default scope.
        assert obj_super_init._create_args_dict(["NAME4", "NAME2"]) == {"NAME2": 2}