import logging

import pytest
import mock

from mlair.data_handler import DefaultDataHandler, DataCollection, AbstractDataHandler
from mlair.helpers.datastore import NameNotFoundInScope
from mlair.helpers import PyTestRegex
from mlair.run_modules.experiment_setup import ExperimentSetup
from mlair.run_modules.pre_processing import PreProcessing
from mlair.run_modules.run_environment import RunEnvironment
import pandas as pd
import numpy as np
import multiprocessing


class TestPreProcessing:
    """Tests for the ``PreProcessing`` run module.

    Most tests work on a ``PreProcessing`` object created with ``object.__new__`` so that
    ``PreProcessing.__init__`` itself is not executed; only the parent initialiser is called.
    Every fixture/test that touches the run environment finishes with
    ``RunEnvironment().__del__()``.
    """

    @pytest.fixture
    def obj_super_init(self):
        """Yield a ``PreProcessing`` object (own ``__init__`` skipped) with a few dummy data store entries."""
        obj = object.__new__(PreProcessing)
        # call only the parent initialiser, not PreProcessing.__init__
        super(PreProcessing, obj).__init__()
        obj.data_store.set("NAME1", 1, "general")
        obj.data_store.set("NAME2", 2, "general")
        obj.data_store.set("NAME3", 3, "general")
        obj.data_store.set("NAME1", 10, "general.sub")
        obj.data_store.set("NAME4", 4, "general.sub.sub")
        yield obj
        RunEnvironment().__del__()

    @pytest.fixture
    def obj_with_exp_setup(self):
        """Yield a ``PreProcessing`` object (own ``__init__`` skipped) after an ``ExperimentSetup`` with six stations."""
        ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background",
                        data_handler=DefaultDataHandler)
        pre = object.__new__(PreProcessing)
        # call only the parent initialiser, not PreProcessing.__init__
        super(PreProcessing, pre).__init__()
        yield pre
        RunEnvironment().__del__()

    def test_init(self, caplog):
        """Check the first and last log records emitted while a full ``PreProcessing`` stage is running."""
        try:
            ExperimentSetup(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                            statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
            caplog.clear()
            caplog.set_level(logging.INFO)
            with PreProcessing():
                assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
                assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started (preprocessing)')
                assert caplog.record_tuples[-3] == (
                    'root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 5 '
                                            r'station\(s\). Found 5/5 valid stations.'))
                assert caplog.record_tuples[-2] == ('root', 20, "Searching for competitors to be prepared for use.")
                assert caplog.record_tuples[-1] == (
                    'root', 20, "No preparation required because no competitor was provided "
                                "to the workflow.")
        finally:
            # run the clean-up even if an assertion above fails, so the run environment is always released
            RunEnvironment().__del__()

    def test_run(self, obj_with_exp_setup):
        """``_run`` returns None and stores a data collection for train, val, train_val and test."""
        assert obj_with_exp_setup.data_store.search_name("data_collection") == []
        assert obj_with_exp_setup._run() is None
        assert obj_with_exp_setup.data_store.search_name("data_collection") == sorted(
            ["general.train", "general.val", "general.train_val", "general.test"])

    def test_split_train_val_test(self, obj_with_exp_setup):
        """``split_train_val_test`` fills the subset scopes; the train scope holds exactly the expected names."""
        assert obj_with_exp_setup.data_store.search_name("data_collection") == []
        obj_with_exp_setup.split_train_val_test()
        data_store = obj_with_exp_setup.data_store
        expected_params = ["data_collection", "start", "end", "stations", "permute_data", "min_length",
                           "extreme_values", "extremes_on_right_tail_only", "upsampling"]
        assert data_store.search_scope("general.train") == sorted(expected_params)
        assert data_store.search_name("data_collection") == sorted(
            ["general.train", "general.val", "general.test", "general.train_val"])

    def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup):
        """With ``use_all_stations_on_all_data_sets`` set to False only the sliced stations are used."""
        caplog.set_level(logging.DEBUG)
        obj_with_exp_setup.data_store.set("use_all_stations_on_all_data_sets", False, "general")
        obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
        assert ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") in caplog.record_tuples
        data_store = obj_with_exp_setup.data_store
        assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
        # the collection is stored in the subset scope only, not in the general scope
        with pytest.raises(NameNotFoundInScope):
            data_store.get("data_collection", "general")
        assert data_store.get("stations", "general.awesome") == ["DEBW107", "DEBY081"]

    def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
        """Without changing ``use_all_stations_on_all_data_sets`` the slice is ignored and all stations are used."""
        caplog.set_level(logging.DEBUG)
        obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
        message = "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']"
        assert ('root', 10, message) in caplog.record_tuples
        data_store = obj_with_exp_setup.data_store
        assert isinstance(data_store.get("data_collection", "general.awesome"), DataCollection)
        # the collection is stored in the subset scope only, not in the general scope
        with pytest.raises(NameNotFoundInScope):
            data_store.get("data_collection", "general")
        # six stations are logged but only five are stored; presumably DEBW001 is rejected during validation
        # (the validate_station tests in this class likewise expect stations[:-1])
        assert data_store.get("stations", "general.awesome") == ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076',
                                                                 'DEBW087']

    @pytest.mark.parametrize("name", (None, "tester"))
    def test_validate_station_serial(self, caplog, obj_with_exp_setup, name):
        """Serial validation keeps all but the last station and logs the set name ('all' if no name is given)."""
        pre = obj_with_exp_setup
        caplog.set_level(logging.INFO)
        stations = pre.data_store.get("stations", "general")
        data_preparation = pre.data_store.get("data_handler")
        collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=name)
        assert isinstance(collection, DataCollection)
        assert len(valid_stations) < len(stations)
        assert valid_stations == stations[:-1]
        expected = "check valid stations started" + ' (%s)' % (name if name else 'all')
        assert caplog.record_tuples[0] == ('root', 20, expected)
        assert caplog.record_tuples[1] == ('root', 20, "use serial validate station approach")
        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
                                                                    r'station\(s\). Found 5/6 valid stations.'))

    def test_validate_station_parallel(self, caplog, obj_with_exp_setup):
        """Parallel validation (3 patched CPUs, pool of 3 workers) keeps all but the last station.

        The worker pool is created inside the test and released by the ``with`` block. Creating it in a
        decorator argument would spawn the worker processes as soon as this module is imported and would
        never terminate them.
        """
        pre = obj_with_exp_setup
        # create the real pool before multiprocessing.Pool is replaced by the mock
        with multiprocessing.Pool(3) as pool:
            caplog.clear()
            caplog.set_level(logging.INFO)
            stations = pre.data_store.get("stations", "general")
            data_preparation = pre.data_store.get("data_handler")
            with mock.patch("psutil.cpu_count", return_value=3), \
                    mock.patch("multiprocessing.Pool", return_value=pool):
                collection, valid_stations = pre.validate_station(data_preparation, stations, set_name=None)
        assert isinstance(collection, DataCollection)
        assert len(valid_stations) < len(stations)
        assert valid_stations == stations[:-1]
        assert caplog.record_tuples[0] == ('root', 20, "check valid stations started (all)")
        assert caplog.record_tuples[1] == ('root', 20, "use parallel validate station approach")
        assert caplog.record_tuples[2] == ('root', 20, "running 3 processes in parallel")
        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+:\d+:\d+ \(hh:mm:ss\) to check 6 '
                                                                    r'station\(s\). Found 5/6 valid stations.'))

    def test_split_set_indices(self, obj_super_init):
        """For 15 elements and a share of 0.9 the returned slices select 10 train, 3 val and 2 test elements."""
        dummy_list = list(range(0, 15))
        train, val, test, train_val = obj_super_init.split_set_indices(len(dummy_list), 0.9)
        assert dummy_list[train] == list(range(0, 10))
        assert dummy_list[val] == list(range(10, 13))
        assert dummy_list[test] == list(range(13, 15))
        assert dummy_list[train_val] == list(range(0, 13))

    def test_transformation(self):
        """``transformation`` returns None for ``AbstractDataHandler`` and for a class without any attributes."""
        pre = object.__new__(PreProcessing)
        data_preparation = AbstractDataHandler
        stations = ['DEBW107', 'DEBY081']
        assert pre.transformation(data_preparation, stations) is None

        class data_preparation_no_trans:
            pass

        assert pre.transformation(data_preparation_no_trans, stations) is None

    # @pytest.fixture
    # def dummy_df(self):
    #     data_dict = {'station_name': {'DEBW013': 'Stuttgart Bad Cannstatt', 'DEBW076': 'Baden-Baden',
    #                                   'DEBW087': 'Schwäbische_Alb', 'DEBW107': 'Tübingen',
    #                                   'DEBY081': 'Garmisch-Partenkirchen/Kreuzeckbahnstraße', '# Stations': np.nan,
    #                                   '# Samples': np.nan},
    #                  'station_lon': {'DEBW013': 9.2297, 'DEBW076': 8.2202, 'DEBW087': 9.2076, 'DEBW107': 9.0512,
    #                                  'DEBY081': 11.0631, '# Stations': np.nan, '# Samples': np.nan},
    #                  'station_lat': {'DEBW013': 48.8088, 'DEBW076': 48.7731, 'DEBW087': 48.3458, 'DEBW107': 48.5077,
    #                                  'DEBY081': 47.4764, '# Stations': np.nan, '# Samples': np.nan},
    #                  'station_alt': {'DEBW013': 235.0, 'DEBW076': 148.0, 'DEBW087': 798.0, 'DEBW107': 325.0,
    #                                  'DEBY081': 735.0, '# Stations': np.nan, '# Samples': np.nan},
    #                  'train': {'DEBW013': 1413, 'DEBW076': 3002, 'DEBW087': 3016, 'DEBW107': 1782, 'DEBY081': 2837,
    #                            '# Stations': 6, '# Samples': 12050},
    #                  'val': {'DEBW013': 698, 'DEBW076': 715, 'DEBW087': 700, 'DEBW107': 701, 'DEBY081': 456,
    #                          '# Stations': 6, '# Samples': 3270},
    #                  'test': {'DEBW013': 1066, 'DEBW076': 696, 'DEBW087': 1080, 'DEBW107': 1080, 'DEBY081': 700,
    #                           '# Stations': 6, '# Samples': 4622}}
    #     df = pd.DataFrame.from_dict(data_dict)
    #     return df