Skip to content
Snippets Groups Projects
Commit f8ded3c6 authored by lukas leufen's avatar lukas leufen
Browse files

further change to investigate

parent 6a34bebb
Branches
Tags
2 merge requests!17update to v0.4.0,!16handle station type
Pipeline #26599 passed
......@@ -13,12 +13,12 @@ class TestDataGenerator:
@pytest.fixture
def gen(self):
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'UBA', 'DEBW107', ['o3', 'temp'],
return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
'datetime', 'variables', 'o3')
def test_init(self, gen):
assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data')
assert gen.network == 'UBA'
assert gen.network == 'AIRBASE'
assert gen.stations == ['DEBW107']
assert gen.variables == ['o3', 'temp']
assert gen.station_type is None
......@@ -34,7 +34,7 @@ class TestDataGenerator:
def test_repr(self, gen):
path = os.path.join(os.path.dirname(__file__), 'data')
assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\
assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='AIRBASE', stations=['DEBW107'], "\
f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \
f"target_dim='variables', target_var='o3', **{{}})".rstrip()
......@@ -43,17 +43,6 @@ class TestDataGenerator:
gen.stations = ['station1', 'station2', 'station3']
assert len(gen) == 3
def test_iter(self, gen):
assert hasattr(gen, '_iterator') is False
iter(gen)
assert hasattr(gen, '_iterator')
assert gen._iterator == 0
def test_next(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
for i, d in enumerate(gen, start=1):
assert i == gen._iterator
def test_getitem(self, caplog, gen):
caplog.set_level(logging.DEBUG)
print(gen)
......@@ -67,6 +56,17 @@ class TestDataGenerator:
assert station[1].data.shape[-1] == gen.window_lead_time
assert station[0].data.shape[1] == gen.window_history + 1
def test_iter(self, gen):
assert hasattr(gen, '_iterator') is False
iter(gen)
assert hasattr(gen, '_iterator')
assert gen._iterator == 0
def test_next(self, gen):
gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
for i, d in enumerate(gen, start=1):
assert i == gen._iterator
def test_get_station_key(self, gen):
gen.stations.append("DEBW108")
f = gen.get_station_key
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment