Skip to content
Snippets Groups Projects
Commit f8ded3c6 authored by lukas leufen's avatar lukas leufen
Browse files

further change to investigate

parent 6a34bebb
No related branches found
No related tags found
2 merge requests!17update to v0.4.0,!16handle station type
Pipeline #26599 passed
......@@ -13,12 +13,12 @@ class TestDataGenerator:
@pytest.fixture
def gen(self):
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'UBA', 'DEBW107', ['o3', 'temp'],
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
'datetime', 'variables', 'o3')
def test_init(self, gen):
assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data')
assert gen.network == 'UBA'
assert gen.network == 'AIRBASE'
assert gen.stations == ['DEBW107']
assert gen.variables == ['o3', 'temp']
assert gen.station_type is None
......@@ -34,7 +34,7 @@ class TestDataGenerator:
def test_repr(self, gen):
path = os.path.join(os.path.dirname(__file__), 'data')
assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\
assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='AIRBASE', stations=['DEBW107'], "\
f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \
f"target_dim='variables', target_var='o3', **{{}})".rstrip()
......@@ -43,17 +43,6 @@ class TestDataGenerator:
gen.stations = ['station1', 'station2', 'station3']
assert len(gen) == 3
def test_iter(self, gen):
assert hasattr(gen, '_iterator') is False
iter(gen)
assert hasattr(gen, '_iterator')
assert gen._iterator == 0
def test_next(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
for i, d in enumerate(gen, start=1):
assert i == gen._iterator
def test_getitem(self, caplog, gen):
caplog.set_level(logging.DEBUG)
print(gen)
......@@ -67,6 +56,17 @@ class TestDataGenerator:
assert station[1].data.shape[-1] == gen.window_lead_time
assert station[0].data.shape[1] == gen.window_history + 1
def test_iter(self, gen):
assert hasattr(gen, '_iterator') is False
iter(gen)
assert hasattr(gen, '_iterator')
assert gen._iterator == 0
def test_next(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
for i, d in enumerate(gen, start=1):
assert i == gen._iterator
def test_get_station_key(self, gen):
gen.stations.append("DEBW108")
f = gen.get_station_key
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment