diff --git a/test/test_data_generator.py b/test/test_data_generator.py index d7b58d32b37df71a7236cd6b61db55b371f7f7ef..e932598a6c3054c0ecbae2f17251ed0fd728aaa9 100644 --- a/test/test_data_generator.py +++ b/test/test_data_generator.py @@ -13,12 +13,12 @@ class TestDataGenerator: @pytest.fixture def gen(self): - return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'UBA', 'DEBW107', ['o3', 'temp'], + return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3') def test_init(self, gen): assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data') - assert gen.network == 'UBA' + assert gen.network == 'AIRBASE' assert gen.stations == ['DEBW107'] assert gen.variables == ['o3', 'temp'] assert gen.station_type is None @@ -34,7 +34,7 @@ class TestDataGenerator: def test_repr(self, gen): path = os.path.join(os.path.dirname(__file__), 'data') - assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\ + assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='AIRBASE', stations=['DEBW107'], "\ f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \ f"target_dim='variables', target_var='o3', **{{}})".rstrip() @@ -43,17 +43,6 @@ class TestDataGenerator: gen.stations = ['station1', 'station2', 'station3'] assert len(gen) == 3 - def test_iter(self, gen): - assert hasattr(gen, '_iterator') is False - iter(gen) - assert hasattr(gen, '_iterator') - assert gen._iterator == 0 - - def test_next(self, gen): - gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}} - for i, d in enumerate(gen, start=1): - assert i == gen._iterator - def test_getitem(self, caplog, gen): caplog.set_level(logging.DEBUG) print(gen) @@ -67,6 +56,17 @@ class TestDataGenerator: assert station[1].data.shape[-1] == gen.window_lead_time assert station[0].data.shape[1] == gen.window_history + 1 + def test_iter(self, gen): + assert hasattr(gen, '_iterator') is False + iter(gen) + assert hasattr(gen, '_iterator') + assert gen._iterator == 0 + + def test_next(self, gen): + gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}} + for i, d in enumerate(gen, start=1): + assert i == gen._iterator + def test_get_station_key(self, gen): gen.stations.append("DEBW108") f = gen.get_station_key