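# Residue from the Git web view this file was copied from (kept as comments so the module parses):
# Select Git revision
# bayes_optimize.ipynb
# test_pre_processing.py 9.77 KiB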
import logging
import pytest
import mock
from mlair.data_handler import DefaultDataHandler, DataCollection, AbstractDataHandler
from mlair.helpers.datastore import NameNotFoundInScope
from mlair.helpers import PyTestRegex
from mlair.run_modules.experiment_setup import ExperimentSetup
from mlair.run_modules.pre_processing import PreProcessing
from mlair.run_modules.run_environment import RunEnvironment
import multiprocessing
class TestPreProcessing:
    """Tests for the ``PreProcessing`` run module.

    The fixtures build a ``PreProcessing`` object via ``object.__new__`` and then call only the
    parent class initialiser, so individual steps (``_run``, ``split_train_val_test``,
    ``create_set_split``, ``validate_station``, ...) can be exercised one at a time. Tests that
    create run-environment state tear it down with ``RunEnvironment().__del__()``.
    """

    @pytest.fixture
    def obj_super_init(self):
        """Yield a bare PreProcessing whose data store holds a few scoped dummy entries."""
        obj = object.__new__(PreProcessing)
        # Skip PreProcessing.__init__ and run only the parent initialiser (sets up data_store).
        super(PreProcessing, obj).__init__()
        obj.data_store.set("NAME1", 1, "general")
        obj.data_store.set("NAME2", 2, "general")
        obj.data_store.set("NAME3", 3, "general")
        obj.data_store.set("NAME1", 10, "general.sub")
        obj.data_store.set("NAME4", 4, "general.sub.sub")
        yield obj
        RunEnvironment().__del__()

    @pytest.fixture
    def obj_with_exp_setup(self):
        """Yield a bare PreProcessing after an ExperimentSetup with 4 stations.

        'DEBW99X' is the last station. The validate_station tests expect it to be the
        one rejected (3/4 valid).
        """
        ExperimentSetup(stations=['DEBW107', 'DEBW013', 'DEBW087', 'DEBW99X'],
                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background",
                        data_origin={'o3': 'UBA', 'temp': 'UBA'}, data_handler=DefaultDataHandler)
        pre = object.__new__(PreProcessing)
        super(PreProcessing, pre).__init__()
        yield pre
        RunEnvironment().__del__()

    @staticmethod
    def _assert_init_log(caplog):
        """Check the INFO log sequence of a full PreProcessing run over 3 stations that are all valid.

        Shared by the ``test_init*`` tests. The leading underscore keeps pytest from collecting it.
        """
        records = caplog.record_tuples
        assert records[0] == ('root', 20, 'PreProcessing started')
        assert records[1] == ('root', 20, 'check valid stations started (preprocessing)')
        assert records[-6] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 3 '
                                                       r'station\(s\). Found 3/3 valid stations.'))
        assert records[-5] == ('root', 20, "use serial create_info_df (train)")
        assert records[-4] == ('root', 20, "use serial create_info_df (val)")
        assert records[-3] == ('root', 20, "use serial create_info_df (test)")
        assert records[-2] == ('root', 20, "Searching for competitors to be prepared for use.")
        assert records[-1] == ('root', 20, "No preparation required for competitor ols as no specific "
                                           "instruction is provided.")

    def test_init(self, caplog):
        """A full PreProcessing run with one statistic per variable logs the expected sequence."""
        try:
            ExperimentSetup(stations=['DEBW087', 'DEBW107', 'DEBW013'],
                            statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'},
                            data_origin={'o3': 'UBA', 'temp': 'UBA'}
                            )
            # Drop the ExperimentSetup logs so indices refer to PreProcessing output only.
            caplog.clear()
            caplog.set_level(logging.INFO)
            with PreProcessing():
                self._assert_init_log(caplog)
        finally:
            # Mirror the fixtures' teardown, and run it even if an assertion above failed.
            RunEnvironment().__del__()

    def test_init_multiple_stat_mix(self, caplog):
        """Same as test_init, but with a list of statistics for one variable ('o3')."""
        try:
            ExperimentSetup(stations=['DEBW087', 'DEBW107', 'DEBW013'],
                            statistics_per_var={'o3': ['dma8eu', 'perc95'], 'temp': 'maximum'},
                            data_origin={'o3': 'UBA', 'temp': 'UBA'}
                            )
            caplog.clear()
            caplog.set_level(logging.INFO)
            with PreProcessing():
                self._assert_init_log(caplog)
        finally:
            RunEnvironment().__del__()

    def test_run(self, obj_with_exp_setup):
        """_run returns None and stores a data_collection in the train/val/train_val/test scopes."""
        assert obj_with_exp_setup.data_store.search_name("data_collection") == []
        assert obj_with_exp_setup._run() is None
        assert obj_with_exp_setup.data_store.search_name("data_collection") == sorted(["general.train", "general.val",
                                                                                       "general.train_val",
                                                                                       "general.test"])

    def test_split_train_val_test(self, obj_with_exp_setup):
        """split_train_val_test fills each set scope with its data collection and set parameters."""
        assert obj_with_exp_setup.data_store.search_name("data_collection") == []
        obj_with_exp_setup.split_train_val_test()
        data_store = obj_with_exp_setup.data_store
        expected_params = ["data_collection", "start", "end", "stations", "permute_data", "min_length",
                           "extreme_values", "extremes_on_right_tail_only", "upsampling"]
        assert data_store.search_scope("general.train") == sorted(expected_params)
        assert data_store.search_name("data_collection") == sorted(["general.train", "general.val", "general.test",
                                                                    "general.train_val"])

    def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup):
        """With use_all_stations_on_all_data_sets=False the index slice restricts the stations."""
        caplog.set_level(logging.DEBUG)
        obj_with_exp_setup.data_store.set("use_all_stations_on_all_data_sets", False, "general")
        obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
        assert ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBW013']") in caplog.record_tuples
        data_store = obj_with_exp_setup.data_store
        assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
        with pytest.raises(NameNotFoundInScope):
            data_store.get("data_collection", "general")
        assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBW013"]

    def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
        """Without overriding use_all_stations_on_all_data_sets, all stations are considered.

        All 4 stations are logged, but only the 3 valid ones are stored for the set.
        """
        caplog.set_level(logging.DEBUG)
        obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
        message = "Awesome stations (len=4): ['DEBW107', 'DEBW013', 'DEBW087', 'DEBW99X']"
        assert ('root', 10, message) in caplog.record_tuples
        data_store = obj_with_exp_setup.data_store
        assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
        with pytest.raises(NameNotFoundInScope):
            data_store.get("data_collection", "general")
        assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBW013', 'DEBW087']

    @pytest.mark.parametrize("name", (None, "tester"))
    def test_validate_station_serial(self, caplog, obj_with_exp_setup, name):
        """Serial validation drops the last station and logs the set name ('all' when name is None)."""
        pre = obj_with_exp_setup
        caplog.set_level(logging.INFO)
        stations = pre.data_store.get("stations", "general")
        data_preparation = pre.data_store.get("data_handler")
        collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=name)
        assert isinstance(collection, DataCollection)
        assert len(valid_stations) < len(stations)
        assert valid_stations == stations[:-1]
        expected = "check valid stations started" + ' (%s)' % (name if name else 'all')
        assert caplog.record_tuples[0] == ('root', 20, expected)
        assert caplog.record_tuples[1] == ('root', 20, "use serial validate station approach")
        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 4 '
                                                                    r'station\(s\). Found 3/4 valid stations.'))

    @mock.patch("psutil.cpu_count", return_value=3)
    @mock.patch("multiprocessing.Pool")
    def test_validate_station_parallel(self, mock_pool, mock_cpu, caplog, obj_with_exp_setup):
        """Parallel validation (3 mocked CPUs) gives the same result as the serial path."""
        pre = obj_with_exp_setup
        caplog.clear()
        caplog.set_level(logging.INFO)
        stations = pre.data_store.get("stations", "general")
        data_preparation = pre.data_store.get("data_handler")
        # Create the real pool here rather than in the decorator, which would spawn processes when
        # the class is defined and never release them. get_context().Pool bypasses the patched
        # module attribute multiprocessing.Pool, and the with-block terminates the workers afterwards.
        with multiprocessing.get_context().Pool(3) as pool:
            mock_pool.return_value = pool
            collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=None)
        assert isinstance(collection, DataCollection)
        assert len(valid_stations) < len(stations)
        assert valid_stations == stations[:-1]
        assert caplog.record_tuples[0] == ('root', 20, "check valid stations started (all)")
        assert caplog.record_tuples[1] == ('root', 20, "use parallel validate station approach")
        assert caplog.record_tuples[2] == ('root', 20, "running 3 processes in parallel")
        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 4 '
                                                                    r'station\(s\). Found 3/4 valid stations.'))

    def test_split_set_indices(self, obj_super_init):
        """15 samples with fraction 0.9 split into train 0-9, val 10-12, test 13-14, train_val 0-12."""
        dummy_list = list(range(0, 15))
        train, val, test, train_val = obj_super_init.split_set_indices(len(dummy_list), 0.9)
        assert dummy_list[train] == list(range(0, 10))
        assert dummy_list[val] == list(range(10, 13))
        assert dummy_list[test] == list(range(13, 15))
        assert dummy_list[train_val] == list(range(0, 13))

    def test_transformation(self):
        """transformation returns None for the abstract handler and for a class with no transformation."""
        pre = object.__new__(PreProcessing)
        data_preparation = AbstractDataHandler
        stations = ['DEBW107', 'DEBY081']
        assert pre.transformation(data_preparation, stations) is None

        class DataPreparationNoTrans:
            pass

        assert pre.transformation(DataPreparationNoTrans, stations) is None