diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index c34d9a24bdfacd3795dd2f64bdbd8017ea8ba71e..f0da24d54dd183d0a4d40667ea0a9619b94f1e6e 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -47,7 +47,7 @@ tests:
     when: always
     paths:
       - badges/
-      - test/
+      - test_results/
 
 coverage:
   tags:
@@ -90,7 +90,7 @@ pages:
     - cp -af coverage/. public/coverage
     - ls public/coverage
     - mkdir -p public/test
-    - cp -af test/. public/test
+    - cp -af test_results/. public/test
     - ls public/test
     - ls public
   when: always
@@ -101,7 +101,7 @@ pages:
       - public
      - badges/
       - coverage/
-      - test/
+      - test_results/
   cache:
     key: old-pages
     paths:
diff --git a/CI/run_pytest.sh b/CI/run_pytest.sh
index d8755448cfaaf9c1477add0929ecc65cc1115ba4..5547b7ab2715e59c123056e56def989bdefdcfeb 100644
--- a/CI/run_pytest.sh
+++ b/CI/run_pytest.sh
@@ -6,14 +6,14 @@ python3 -m pytest --html=report.html --self-contained-html test/ | tee test_resu
 IS_FAILED=$?
 
 # move html test report
-mkdir test/
+mkdir test_results/
 BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}")
-mkdir test/${BRANCH_NAME}
-mkdir test/recent
-cp report.html test/${BRANCH_NAME}/.
-cp report.html test/recent/.
+mkdir test_results/${BRANCH_NAME}
+mkdir test_results/recent
+cp report.html test_results/${BRANCH_NAME}/.
+cp report.html test_results/recent/.
 if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then
-    cp -r report.html test/.
+    cp -r report.html test_results/.
 fi
 
 # exit 0 if no tests implemented
diff --git a/src/data_generator.py b/src/data_generator.py
index 3d8a1c7c242da3d45a1b17361e210a016a419dd6..860791235f111a7ffb151f2b06424be76dc8eba7 100644
--- a/src/data_generator.py
+++ b/src/data_generator.py
@@ -19,9 +19,9 @@ class DataGenerator(keras.utils.Sequence):
     """
 
     def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str],
-                 interpolate_dim: str, target_dim: str, target_var: str, interpolate_method: str = "linear",
-                 limit_nan_fill: int = 1, window_history: int = 7, window_lead_time: int = 4,
-                 transform_method: str = "standardise", **kwargs):
+                 interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None,
+                 interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history: int = 7,
+                 window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
         self.data_path = os.path.abspath(data_path)
         self.network = network
         self.stations = helpers.to_list(stations)
@@ -29,6 +29,7 @@ class DataGenerator(keras.utils.Sequence):
         self.interpolate_dim = interpolate_dim
         self.target_dim = target_dim
         self.target_var = target_var
+        self.station_type = station_type
         self.interpolate_method = interpolate_method
         self.limit_nan_fill = limit_nan_fill
         self.window_history = window_history
@@ -41,8 +42,9 @@
         display all class attributes
         """
         return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \
-               f"variables={self.variables}, interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}'" \
-               f", target_var='{self.target_var}', **{self.kwargs})"
+               f"variables={self.variables}, station_type={self.station_type}, " \
+               f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \
+               f"target_var='{self.target_var}', **{self.kwargs})"
 
     def __len__(self):
         """
@@ -94,7 +96,8 @@ class DataGenerator(keras.utils.Sequence):
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
-        data = DataPrep(self.data_path, self.network, station, self.variables, **self.kwargs)
+        data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                        **self.kwargs)
         data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
         data.transform("datetime", method=self.transform_method)
         data.make_history_window(self.interpolate_dim, self.window_history)
diff --git a/src/data_preparation.py b/src/data_preparation.py
index 873433f499f51c003988d8b33da7a525d14544fa..3c50ba893563780dfd8ac92f36fffabc38ed16a9 100644
--- a/src/data_preparation.py
+++ b/src/data_preparation.py
@@ -44,11 +44,13 @@ class DataPrep(object):
 
     """
 
-    def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str], **kwargs):
+    def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str],
+                 station_type: str = None, **kwargs):
         self.path = os.path.abspath(path)
         self.network = network
         self.station = helpers.to_list(station)
         self.variables = variables
+        self.station_type = station_type
         self.mean = None
         self.std = None
         self.history = None
@@ -75,14 +77,36 @@ class DataPrep(object):
         file_name = self._set_file_name()
         meta_file = self._set_meta_file_name()
         try:
+
+            logging.debug(f"try to load local data from: {file_name}")
             data = self._slice_prep(xr.open_dataarray(file_name))
             self.data = self.check_for_negative_concentrations(data)
             self.meta = pd.read_csv(meta_file, index_col=0)
+            if self.station_type is not None:
+                self.check_station_meta()
+            logging.debug("loading finished")
         except FileNotFoundError as e:
             logging.warning(e)
             data, self.meta = self.download_data_from_join(file_name, meta_file)
             data = self._slice_prep(data)
             self.data = self.check_for_negative_concentrations(data)
+            logging.debug("loaded new data from JOIN")
+
+    def check_station_meta(self):
+        """
+        Search for the entries in meta data and compare the values with the requested values. Raise a FileNotFoundError
+        if the values mismatch.
+        """
+        check_dict = {
+            "station_type": self.station_type,
+            "network_name": self.network
+        }
+        for (k, v) in check_dict.items():
+            if self.meta.at[k, self.station[0]] != v:
+                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
+                              f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
+                              f"grabbing from web.")
+                raise FileNotFoundError
 
     def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]:
         """
@@ -92,7 +116,8 @@ class DataPrep(object):
         :return:
         """
         df_all = {}
-        df, meta = join.download_join(station_name=self.station, statvar=self.statistics_per_var)
+        df, meta = join.download_join(station_name=self.station, statvar=self.statistics_per_var,
+                                      station_type=self.station_type, network_name=self.network)
         df_all[self.station[0]] = df
         # convert df_all to xarray
         xarr = {k: xr.DataArray(v, dims=['datetime', 'variables']) for k, v in df_all.items()}
@@ -111,7 +136,7 @@ class DataPrep(object):
 
     def __repr__(self):
         return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \
-               f"variables={self.variables}, **{self.kwargs})"
+               f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})"
 
     def interpolate(self, dim: str, method: str = 'linear', limit: int = None,
                     use_coordinate: Union[bool, str] = True, **kwargs):
diff --git a/src/join.py b/src/join.py
index 2b13dcf41c5bc03e9dba274fd5e643c79b091cde..4f9f36f960bc5a757a70b39222fb183ccec7aa8f 100644
--- a/src/join.py
+++ b/src/join.py
@@ -3,7 +3,6 @@ __date__ = '2019-10-16'
 
 
 import requests
-import json
 import logging
 import pandas as pd
 import datetime as dt
@@ -13,12 +12,21 @@ from src import helpers
 join_url_base = 'https://join.fz-juelich.de/services/rest/surfacedata/'
 
 
-def download_join(station_name: Union[str, List[str]], statvar: dict) -> [pd.DataFrame, pd.DataFrame]:
+class EmptyQueryResult(Exception):
+    """
+    Exception that gets raised if a query to JOIN returns empty results.
+    """
+    pass
+
+
+def download_join(station_name: Union[str, List[str]], statvar: dict, station_type: str = None, network_name: str = None) -> [pd.DataFrame, pd.DataFrame]:
     """
     read data from JOIN/TOAR
     :param station_name: Station name e.g. DEBY122
     :param statvar: key as variable like 'O3', values as statistics on keys like 'mean'
+    :param station_type: restrict request to stations of this type (optional)
+    :param network_name: restrict request to this measurement network (optional)
     :returns:
         - df - pandas df with all variables and statistics
         - meta - pandas df with all meta information
@@ -27,7 +35,8 @@ def download_join(station_name: Union[str, List[str]], statvar: dict) -> [pd.Dat
     station_name = helpers.to_list(station_name)
 
     # load series information
-    opts = {'base': join_url_base, 'service': 'series', 'station_id': station_name[0]}
+    opts = {"base": join_url_base, "service": "series", "station_id": station_name[0], "station_type": station_type,
+            "network_name": network_name}
     url = create_url(**opts)
     response = requests.get(url)
     station_vars = response.json()
@@ -65,7 +74,7 @@ def download_join(station_name: Union[str, List[str]], statvar: dict) -> [pd.Dat
         meta.columns = station_name
         return df, meta
     else:
-        raise ValueError("No data found in JOIN.")
+        raise EmptyQueryResult("No data found in JOIN.")
 
 
 def _correct_stat_name(stat: str) -> str:
@@ -97,7 +106,7 @@ def create_url(base: str, service: str, **kwargs: Union[str, int, float]) -> str
     :param kwargs: keyword pairs for optional request specifications, e.g. 'statistics=maximum'
     :return: combined url as string
     """
-    url = '{}{}/?'.format(base, service) + '&'.join('{}={}'.format(k, v) for k, v in kwargs.items())
+    url = '{}{}/?'.format(base, service) + '&'.join('{}={}'.format(k, v) for k, v in kwargs.items() if v is not None)
     return url
 
 
diff --git a/src/modules/experiment_setup.py b/src/modules/experiment_setup.py
index a76fe60b34b679b5702ec85a11f95002c3c6fe34..a20f0b83e9828550d2f717502b5371c2c1ad7e9a 100644
--- a/src/modules/experiment_setup.py
+++ b/src/modules/experiment_setup.py
@@ -27,7 +27,7 @@ class ExperimentSetup(RunEnvironment):
     trainable: Train new model if true, otherwise try to load existing model
     """
 
-    def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, variables=None,
+    def __init__(self, parser_args=None, var_all_dict=None, stations=None, network=None, station_type=None, variables=None,
                  statistics_per_var=None, start=None, end=None, window_history=None, target_var="o3",
                  target_dim=None, window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None,
                  limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None,
@@ -53,6 +53,7 @@ class ExperimentSetup(RunEnvironment):
         self._set_param("var_all_dict", var_all_dict, default=DEFAULT_VAR_ALL_DICT)
         self._set_param("stations", stations, default=DEFAULT_STATIONS)
         self._set_param("network", network, default="AIRBASE")
+        self._set_param("station_type", station_type, default=None)
         self._set_param("variables", variables, default=list(self.data_store.get("var_all_dict", "general").keys()))
         self._set_param("statistics_per_var", statistics_per_var, default=self.data_store.get("var_all_dict", "general"))
         self._set_param("start", start, default="1997-01-01", scope="general")
diff --git a/src/modules/modules.py b/src/modules/modules.py
index 8532e1d812a5de7a3b47423d9f9bb3c9bcd43abc..033fd0779d8d140e684103b27fc7c025dedcdb81 100644
--- a/src/modules/modules.py
+++ b/src/modules/modules.py
@@ -1,8 +1,9 @@
 import logging
-# from src.experiment_setup import ExperimentSetup
 import argparse
 
 from src.modules.run_environment import RunEnvironment
+from src.modules.experiment_setup import ExperimentSetup
+from src.modules.pre_processing import PreProcessing
 
 
 class Training(RunEnvironment):
@@ -28,6 +29,7 @@ if __name__ == "__main__":
     parser.add_argument('--experiment_date', metavar='--exp_date', type=str, nargs=1, default=None,
                         help="set experiment date as string")
     parser_args = parser.parse_args()
-    # with run():
-    #     setup = ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'])
-    #     PreProcessing(setup)
+    with RunEnvironment():
+        ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
+                        station_type='background')
+        PreProcessing()
diff --git a/src/modules/pre_processing.py b/src/modules/pre_processing.py
index d999217e9f903d3d67a24179c9f3654fee3e60d4..d3056f52bd0a60e0c9e7ed97fa593f3b596898a4 100644
--- a/src/modules/pre_processing.py
+++ b/src/modules/pre_processing.py
@@ -5,10 +5,11 @@ from src.data_generator import DataGenerator
 from src.helpers import TimeTracking
 from src.modules.run_environment import RunEnvironment
 from src.datastore import NameNotFoundInDataStore, NameNotFoundInScope
+from src.join import EmptyQueryResult
 
 
 DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"]
-DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history", "window_lead_time", "statistics_per_var"]
"statistics_per_var"] +DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history", "window_lead_time", "statistics_per_var", "station_type"] class PreProcessing(RunEnvironment): @@ -110,7 +111,8 @@ class PreProcessing(RunEnvironment): valid_stations.append(station) logging.debug(f"{station}: history_shape = {history.shape}") logging.debug(f"{station}: loading time = {t_inner}") - except AttributeError: + except (AttributeError, EmptyQueryResult): continue - logging.info(f"run for {t_outer} to check {len(all_stations)} station(s)") + logging.info(f"run for {t_outer} to check {len(all_stations)} station(s). Found {len(valid_stations)}/" + f"{len(all_stations)} valid stations.") return valid_stations diff --git a/test/test_data_generator.py b/test/test_data_generator.py index 7c745782bc057060dd439af1fb6e03c5b3ef5730..2fe8b8c0b5a7f4f8be9626b0061702acb53ecb6b 100644 --- a/test/test_data_generator.py +++ b/test/test_data_generator.py @@ -1,26 +1,21 @@ import pytest import os from src.data_generator import DataGenerator -import logging -import numpy as np -import xarray as xr -import datetime as dt -import pandas as pd -from operator import itemgetter class TestDataGenerator: @pytest.fixture def gen(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'UBA', 'DEBW107', ['o3', 'temp'], + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3') def test_init(self, gen): assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data') - assert gen.network == 'UBA' + assert gen.network == 'AIRBASE' assert gen.stations == ['DEBW107'] assert gen.variables == ['o3', 'temp'] + assert gen.station_type is None assert gen.interpolate_dim == 'datetime' assert gen.target_dim == 'variables' assert gen.target_var == 'o3' @@ -33,8 +28,8 @@ class TestDataGenerator: def test_repr(self, gen): path = os.path.join(os.path.dirname(__file__), 'data') - assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\ - f"variables=['o3', 'temp'], interpolate_dim='datetime', " \ + assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='AIRBASE', stations=['DEBW107'], "\ + f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \ f"target_dim='variables', target_var='o3', **{{}})".rstrip() def test_len(self, gen): @@ -42,6 +37,15 @@ class TestDataGenerator: gen.stations = ['station1', 'station2', 'station3'] assert len(gen) == 3 + def test_getitem(self, gen): + gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}} + station = gen["DEBW107"] + assert len(station) == 2 + assert station[0].Stations.data == "DEBW107" + assert station[0].data.shape[1:] == (8, 1, 2) + assert station[1].data.shape[-1] == gen.window_lead_time + assert station[0].data.shape[1] == gen.window_history + 1 + def test_iter(self, gen): assert hasattr(gen, '_iterator') is False iter(gen) @@ -53,15 +57,6 @@ class TestDataGenerator: for i, d in enumerate(gen, start=1): assert i == gen._iterator - def test_getitem(self, gen): - gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}} - station = gen["DEBW107"] - assert len(station) == 2 - assert station[0].Stations.data == "DEBW107" - assert station[0].data.shape[1:] == (8, 1, 2) - assert station[1].data.shape[-1] == gen.window_lead_time - assert station[0].data.shape[1] == gen.window_history + 1 - def test_get_station_key(self, gen): gen.stations.append("DEBW108") f = 
diff --git a/test/test_data_preparation.py b/test/test_data_preparation.py
index 5d45c041b6e669cced56172d41fc2f9653dd30e7..30f93e6d734885252d2c7a438d6065aa680f32f8 100644
--- a/test/test_data_preparation.py
+++ b/test/test_data_preparation.py
@@ -1,6 +1,7 @@
 import pytest
 import os
 from src.data_preparation import DataPrep
+from src.join import EmptyQueryResult
 import logging
 import numpy as np
 import xarray as xr
@@ -13,16 +14,18 @@ class TestDataPrep:
 
     @pytest.fixture
     def data(self):
-        return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'dummy', 'DEBW107', ['o3', 'temp'],
-                        test='testKWARGS', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+        return DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
+                        station_type='background', test='testKWARGS',
+                        statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
 
     def test_init(self, data):
         assert data.path == os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
-        assert data.network == 'dummy'
+        assert data.network == 'AIRBASE'
         assert data.station == ['DEBW107']
         assert data.variables == ['o3', 'temp']
+        assert data.station_type == "background"
         assert data.statistics_per_var == {'o3': 'dma8eu', 'temp': 'maximum'}
-        assert not all([data.mean, data.std, data.history, data.label])
+        assert not all([data.mean, data.std, data.history, data.label, data.station_type])
         assert {'test': 'testKWARGS'}.items() <= data.kwargs.items()
 
     def test_init_no_stats(self):
@@ -35,9 +38,10 @@ class TestDataPrep:
         d.network = 'dummy'
         d.station = ['DEBW107']
         d.variables = ['o3', 'temp']
+        d.station_type = "traffic"
         d.kwargs = None
         assert d.__repr__().rstrip() == "Dataprep(path='data/test', network='dummy', station=['DEBW107'], "\
-                                        "variables=['o3', 'temp'], **None)".rstrip()
+                                        "variables=['o3', 'temp'], station_type=traffic, **None)".rstrip()
 
     def test_set_file_name_and_meta(self):
         d = object.__new__(DataPrep)
@@ -229,3 +233,9 @@ class TestDataPrep:
         assert res.sel({'variables': 'o3'}).min() >= 0
         res = data.check_for_negative_concentrations(data.data, minimum=2)
         assert res.sel({'variables': 'o3'}).min() >= 2
+
+    def test_check_station(self, data):
+        with pytest.raises(EmptyQueryResult):
+            data_new = DataPrep(os.path.join(os.path.dirname(__file__), 'data'), 'dummy', 'DEBW107', ['o3', 'temp'],
+                                station_type='traffic', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
+
diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py
index 832ff45a0a3b1384e1300c0fa38ed3d1ec2204b8..be3db59e2415f28ea63d42b7cc6ced6b2c095700 100644
--- a/test/test_modules/test_experiment_setup.py
+++ b/test/test_modules/test_experiment_setup.py
@@ -48,7 +48,7 @@ class TestExperimentSetup:
         # experiment setup
         assert data_store.get("data_path", "general") == prepare_host()
         assert data_store.get("trainable", "general") is False
-        assert data_store.get("fraction_of_train", "general") == 0.8
+        assert data_store.get("fraction_of_training", "general") == 0.8
         # set experiment name
         assert data_store.get("experiment_name", "general") == "TestExperiment"
         path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment"))
@@ -67,6 +67,7 @@ class TestExperimentSetup:
                             'DEBW052', 'DEBW034', 'DEBY088', ]
         assert data_store.get("stations", "general") == default_stations
         assert data_store.get("network", "general") == "AIRBASE"
+        assert data_store.get("station_type", "general") is None
         assert data_store.get("variables", "general") == list(default_var_all_dict.keys())
         assert data_store.get("statistics_per_var", "general") == default_var_all_dict
         assert data_store.get("start", "general") == "1997-01-01"
@@ -97,7 +98,8 @@ class TestExperimentSetup:
         experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
         kwargs = dict(parser_args={"experiment_date": "TODAY"},
                       var_all_dict={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'},
-                      stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", variables=["o3", "temp"],
+                      stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", station_type="background",
+                      variables=["o3", "temp"],
                       statistics_per_var=None, start="1999-01-01", end="2001-01-01", window_history=4,
                       target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1",
                       interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
@@ -109,7 +111,7 @@ class TestExperimentSetup:
         # experiment setup
         assert data_store.get("data_path", "general") == prepare_host()
         assert data_store.get("trainable", "general") is True
-        assert data_store.get("fraction_of_train", "general") == 0.5
+        assert data_store.get("fraction_of_training", "general") == 0.5
         # set experiment name
         assert data_store.get("experiment_name", "general") == "TODAY_network/"
         path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
@@ -119,6 +121,7 @@ class TestExperimentSetup:
                                                                      'temp': 'maximum'}
         assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027']
         assert data_store.get("network", "general") == "INTERNET"
+        assert data_store.get("station_type", "general") == "background"
         assert data_store.get("variables", "general") == ["o3", "temp"]
         assert data_store.get("statistics_per_var", "general") == {'o3': 'dma8eu', 'relhum': 'average_values',
                                                                    'temp': 'maximum'}
diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py
index bc121885ddb8ee20b0f571e7f0250845c6e99e6a..1af910ee660510c5667e6170c82c079dcf515bb2 100644
--- a/test/test_modules/test_pre_processing.py
+++ b/test/test_modules/test_pre_processing.py
@@ -1,7 +1,8 @@
 import logging
 import pytest
+import time
 
-from src.helpers import PyTestRegex, TimeTracking
+from src.helpers import PyTestRegex
 from src.modules.experiment_setup import ExperimentSetup
 from src.modules.pre_processing import PreProcessing, DEFAULT_ARGS_LIST, DEFAULT_KWARGS_LIST
 from src.data_generator import DataGenerator
@@ -13,7 +14,8 @@ class TestPreProcessing:
 
     @pytest.fixture
     def obj_no_init(self):
-        return object.__new__(PreProcessing)
+        yield object.__new__(PreProcessing)
+        RunEnvironment().__del__()
 
     @pytest.fixture
     def obj_super_init(self):
@@ -29,8 +31,8 @@ class TestPreProcessing:
 
     @pytest.fixture
     def obj_with_exp_setup(self):
-        ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
-                        var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
+        ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
+                        var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background")
         pre = object.__new__(PreProcessing)
         super(PreProcessing, pre).__init__()
         yield pre
@@ -43,7 +45,8 @@ class TestPreProcessing:
             PreProcessing()
         assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
         assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[-2] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)'))
+        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). Found '
+                                                                    r'5/5 valid stations.'))
         RunEnvironment().__del__()
 
     def test_run(self, obj_with_exp_setup):
@@ -73,8 +76,8 @@ class TestPreProcessing:
     def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
         caplog.set_level(logging.DEBUG)
         obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
-        assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=5): ['DEBW107', 'DEBY081', 'DEBW013', "
-                                                       "'DEBW076', 'DEBW087']")
+        assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', "
+                                                       "'DEBW076', 'DEBW087', 'DEBW001']")
         data_store = obj_with_exp_setup.data_store
         assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
         with pytest.raises(NameNotFoundInScope):
@@ -88,9 +91,11 @@ class TestPreProcessing:
         kwargs = pre._create_args_dict(DEFAULT_KWARGS_LIST)
         stations = pre.data_store.get("stations", "general")
         valid_stations = pre.check_valid_stations(args, kwargs, stations)
-        assert valid_stations == stations
+        assert len(valid_stations) < len(stations)
+        assert valid_stations == stations[:-1]
         assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)'))
+        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 6 station\(s\). Found '
+                                                                    r'5/6 valid stations.'))
 
     def test_split_set_indices(self, obj_no_init):
         dummy_list = list(range(0, 15))