Skip to content
Snippets Groups Projects
Commit 873cdc31 authored by lukas leufen
Browse files

updated tests

parent d4abd2ac
No related branches found
No related tags found
2 merge requests: !17 update to v0.4.0, !16 handle station type
Pipeline #26582 passed
...@@ -42,7 +42,7 @@ class DataGenerator(keras.utils.Sequence): ...@@ -42,7 +42,7 @@ class DataGenerator(keras.utils.Sequence):
display all class attributes display all class attributes
""" """
return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \ return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \
f"variables={self.variables}, station_type='{self.station_type}', " \ f"variables={self.variables}, station_type={self.station_type}, " \
f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \ f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \
f"target_var='{self.target_var}', **{self.kwargs})" f"target_var='{self.target_var}', **{self.kwargs})"
......
...@@ -89,9 +89,11 @@ class DataPrep(object): ...@@ -89,9 +89,11 @@ class DataPrep(object):
self.data = self.check_for_negative_concentrations(data) self.data = self.check_for_negative_concentrations(data)
def check_station_type(self): def check_station_type(self):
"""
Search for the `station_type` entry in meta data and compare the value with the requested station_type. Raise
an EmptyQueryResult error if the values mismatch.
"""
if self.meta.at["station_type", self.station[0]] != self.station_type: if self.meta.at["station_type", self.station[0]] != self.station_type:
self.data = None
self.meta = None
raise join.EmptyQueryResult raise join.EmptyQueryResult
def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]: def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]:
......
...@@ -21,6 +21,7 @@ class TestDataGenerator: ...@@ -21,6 +21,7 @@ class TestDataGenerator:
assert gen.network == 'UBA' assert gen.network == 'UBA'
assert gen.stations == ['DEBW107'] assert gen.stations == ['DEBW107']
assert gen.variables == ['o3', 'temp'] assert gen.variables == ['o3', 'temp']
assert gen.station_type is None
assert gen.interpolate_dim == 'datetime' assert gen.interpolate_dim == 'datetime'
assert gen.target_dim == 'variables' assert gen.target_dim == 'variables'
assert gen.target_var == 'o3' assert gen.target_var == 'o3'
...@@ -34,7 +35,7 @@ class TestDataGenerator: ...@@ -34,7 +35,7 @@ class TestDataGenerator:
def test_repr(self, gen): def test_repr(self, gen):
path = os.path.join(os.path.dirname(__file__), 'data') path = os.path.join(os.path.dirname(__file__), 'data')
assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\ assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\
f"variables=['o3', 'temp'], interpolate_dim='datetime', " \ f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \
f"target_dim='variables', target_var='o3', **{{}})".rstrip() f"target_dim='variables', target_var='o3', **{{}})".rstrip()
def test_len(self, gen): def test_len(self, gen):
......
...@@ -22,7 +22,7 @@ class TestDataPrep: ...@@ -22,7 +22,7 @@ class TestDataPrep:
assert data.station == ['DEBW107'] assert data.station == ['DEBW107']
assert data.variables == ['o3', 'temp'] assert data.variables == ['o3', 'temp']
assert data.statistics_per_var == {'o3': 'dma8eu', 'temp': 'maximum'} assert data.statistics_per_var == {'o3': 'dma8eu', 'temp': 'maximum'}
assert not all([data.mean, data.std, data.history, data.label]) assert not all([data.mean, data.std, data.history, data.label, data.station_type])
assert {'test': 'testKWARGS'}.items() <= data.kwargs.items() assert {'test': 'testKWARGS'}.items() <= data.kwargs.items()
def test_init_no_stats(self): def test_init_no_stats(self):
...@@ -35,9 +35,10 @@ class TestDataPrep: ...@@ -35,9 +35,10 @@ class TestDataPrep:
d.network = 'dummy' d.network = 'dummy'
d.station = ['DEBW107'] d.station = ['DEBW107']
d.variables = ['o3', 'temp'] d.variables = ['o3', 'temp']
d.station_type = "traffic"
d.kwargs = None d.kwargs = None
assert d.__repr__().rstrip() == "Dataprep(path='data/test', network='dummy', station=['DEBW107'], "\ assert d.__repr__().rstrip() == "Dataprep(path='data/test', network='dummy', station=['DEBW107'], "\
"variables=['o3', 'temp'], **None)".rstrip() "variables=['o3', 'temp'], station_type='traffic', **None)".rstrip()
def test_set_file_name_and_meta(self): def test_set_file_name_and_meta(self):
d = object.__new__(DataPrep) d = object.__new__(DataPrep)
......
...@@ -48,7 +48,7 @@ class TestExperimentSetup: ...@@ -48,7 +48,7 @@ class TestExperimentSetup:
# experiment setup # experiment setup
assert data_store.get("data_path", "general") == prepare_host() assert data_store.get("data_path", "general") == prepare_host()
assert data_store.get("trainable", "general") is False assert data_store.get("trainable", "general") is False
assert data_store.get("fraction_of_train", "general") == 0.8 assert data_store.get("fraction_of_training", "general") == 0.8
# set experiment name # set experiment name
assert data_store.get("experiment_name", "general") == "TestExperiment" assert data_store.get("experiment_name", "general") == "TestExperiment"
path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment")) path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment"))
...@@ -67,6 +67,7 @@ class TestExperimentSetup: ...@@ -67,6 +67,7 @@ class TestExperimentSetup:
'DEBW052', 'DEBW034', 'DEBY088', ] 'DEBW052', 'DEBW034', 'DEBY088', ]
assert data_store.get("stations", "general") == default_stations assert data_store.get("stations", "general") == default_stations
assert data_store.get("network", "general") == "AIRBASE" assert data_store.get("network", "general") == "AIRBASE"
assert data_store.get("station_type", "general") is None
assert data_store.get("variables", "general") == list(default_var_all_dict.keys()) assert data_store.get("variables", "general") == list(default_var_all_dict.keys())
assert data_store.get("statistics_per_var", "general") == default_var_all_dict assert data_store.get("statistics_per_var", "general") == default_var_all_dict
assert data_store.get("start", "general") == "1997-01-01" assert data_store.get("start", "general") == "1997-01-01"
...@@ -97,7 +98,8 @@ class TestExperimentSetup: ...@@ -97,7 +98,8 @@ class TestExperimentSetup:
experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder")) experiment_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
kwargs = dict(parser_args={"experiment_date": "TODAY"}, kwargs = dict(parser_args={"experiment_date": "TODAY"},
var_all_dict={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'}, var_all_dict={'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum'},
stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", variables=["o3", "temp"], stations=['DEBY053', 'DEBW059', 'DEBW027'], network="INTERNET", station_type="background",
variables=["o3", "temp"],
statistics_per_var=None, start="1999-01-01", end="2001-01-01", window_history=4, statistics_per_var=None, start="1999-01-01", end="2001-01-01", window_history=4,
target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1", target_var="temp", target_dim="target", window_lead_time=10, dimensions="dim1",
interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01", interpolate_dim="int_dim", interpolate_method="cubic", limit_nan_fill=5, train_start="2000-01-01",
...@@ -109,7 +111,7 @@ class TestExperimentSetup: ...@@ -109,7 +111,7 @@ class TestExperimentSetup:
# experiment setup # experiment setup
assert data_store.get("data_path", "general") == prepare_host() assert data_store.get("data_path", "general") == prepare_host()
assert data_store.get("trainable", "general") is True assert data_store.get("trainable", "general") is True
assert data_store.get("fraction_of_train", "general") == 0.5 assert data_store.get("fraction_of_training", "general") == 0.5
# set experiment name # set experiment name
assert data_store.get("experiment_name", "general") == "TODAY_network/" assert data_store.get("experiment_name", "general") == "TODAY_network/"
path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder")) path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder"))
...@@ -119,6 +121,7 @@ class TestExperimentSetup: ...@@ -119,6 +121,7 @@ class TestExperimentSetup:
'temp': 'maximum'} 'temp': 'maximum'}
assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027'] assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027']
assert data_store.get("network", "general") == "INTERNET" assert data_store.get("network", "general") == "INTERNET"
assert data_store.get("station_type", "general") == "background"
assert data_store.get("variables", "general") == ["o3", "temp"] assert data_store.get("variables", "general") == ["o3", "temp"]
assert data_store.get("statistics_per_var", "general") == {'o3': 'dma8eu', 'relhum': 'average_values', assert data_store.get("statistics_per_var", "general") == {'o3': 'dma8eu', 'relhum': 'average_values',
'temp': 'maximum'} 'temp': 'maximum'}
......
import logging import logging
import pytest import pytest
from src.helpers import PyTestRegex, TimeTracking from src.helpers import PyTestRegex
from src.modules.experiment_setup import ExperimentSetup from src.modules.experiment_setup import ExperimentSetup
from src.modules.pre_processing import PreProcessing, DEFAULT_ARGS_LIST, DEFAULT_KWARGS_LIST from src.modules.pre_processing import PreProcessing, DEFAULT_ARGS_LIST, DEFAULT_KWARGS_LIST
from src.data_generator import DataGenerator from src.data_generator import DataGenerator
...@@ -29,8 +29,8 @@ class TestPreProcessing: ...@@ -29,8 +29,8 @@ class TestPreProcessing:
@pytest.fixture @pytest.fixture
def obj_with_exp_setup(self): def obj_with_exp_setup(self):
ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'],
var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}, station_type="background")
pre = object.__new__(PreProcessing) pre = object.__new__(PreProcessing)
super(PreProcessing, pre).__init__() super(PreProcessing, pre).__init__()
yield pre yield pre
...@@ -73,8 +73,8 @@ class TestPreProcessing: ...@@ -73,8 +73,8 @@ class TestPreProcessing:
def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup): def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup):
caplog.set_level(logging.DEBUG) caplog.set_level(logging.DEBUG)
obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") obj_with_exp_setup.create_set_split(slice(0, 2), "awesome")
assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=5): ['DEBW107', 'DEBY081', 'DEBW013', " assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', "
"'DEBW076', 'DEBW087']") "'DEBW076', 'DEBW087', 'DEBW001']")
data_store = obj_with_exp_setup.data_store data_store = obj_with_exp_setup.data_store
assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator) assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator)
with pytest.raises(NameNotFoundInScope): with pytest.raises(NameNotFoundInScope):
...@@ -88,9 +88,11 @@ class TestPreProcessing: ...@@ -88,9 +88,11 @@ class TestPreProcessing:
kwargs = pre._create_args_dict(DEFAULT_KWARGS_LIST) kwargs = pre._create_args_dict(DEFAULT_KWARGS_LIST)
stations = pre.data_store.get("stations", "general") stations = pre.data_store.get("stations", "general")
valid_stations = pre.check_valid_stations(args, kwargs, stations) valid_stations = pre.check_valid_stations(args, kwargs, stations)
assert valid_stations == stations assert len(valid_stations) < len(stations)
assert valid_stations == stations[:-1]
assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started') assert caplog.record_tuples[0] == ('root', 20, 'check valid stations started')
assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\)')) assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 6 station\(s\). Found '
r'5/6 valid stations.'))
def test_split_set_indices(self, obj_no_init): def test_split_set_indices(self, obj_no_init):
dummy_list = list(range(0, 15)) dummy_list = list(range(0, 15))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment