import datetime as dt
import logging
import os
from operator import itemgetter

import numpy as np
import pandas as pd
import pytest
import xarray as xr

from src.data_generator import DataGenerator


class TestDataGenerator:
    """Tests for ``DataGenerator`` using the ``data`` directory next to this file."""

    @pytest.fixture
    def gen(self):
        """Return a fresh ``DataGenerator`` for station DEBW107 with variables o3 and temp."""
        data_path = os.path.join(os.path.dirname(__file__), 'data')
        return DataGenerator(data_path, 'UBA', 'DEBW107', ['o3', 'temp'],
                             'datetime', 'variables', 'o3')

    def test_init(self, gen):
        """Check that constructor arguments are stored and the remaining attributes get defaults."""
        # Values passed explicitly by the fixture. Note that the single station
        # string is expected to be stored as a one-element list.
        assert gen.path == os.path.join(os.path.dirname(__file__), 'data')
        assert gen.network == 'UBA'
        assert gen.stations == ['DEBW107']
        assert gen.variables == ['o3', 'temp']
        assert gen.interpolate_dim == 'datetime'
        assert gen.target_dim == 'variables'
        assert gen.target_var == 'o3'
        # Values the fixture did not pass, i.e. the expected defaults.
        assert gen.interpolate_method == "linear"
        assert gen.limit_nan_fill == 1
        assert gen.window_history == 7
        assert gen.window_lead_time == 4
        assert gen.transform_method == "standardise"
        assert gen.kwargs == {}
        assert gen.threshold is not None

    def test_repr(self, gen):
        """Check the textual representation, ignoring trailing whitespace on both sides."""
        path = os.path.join(os.path.dirname(__file__), 'data')
        expected = (
            f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "
            "variables=['o3', 'temp'], interpolate_dim='datetime', "
            "target_dim='variables', target_var='o3', **{})"
        )
        assert repr(gen).rstrip() == expected.rstrip()

    def test_len(self, gen):
        """Check that ``len(gen)`` follows the number of entries in ``gen.stations``."""
        assert len(gen) == 1
        gen.stations = ['station1', 'station2', 'station3']
        assert len(gen) == 3

    def test_iter(self, gen):
        """Check that ``iter(gen)`` creates the ``_iterator`` attribute and sets it to 0."""
        assert not hasattr(gen, '_iterator')
        iter(gen)
        assert hasattr(gen, '_iterator')
        assert gen._iterator == 0

    def test_next(self, gen):
        """Check that ``_iterator`` equals the 1-based count of items yielded so far."""
        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
        # Counting starts at 1 because the comparison happens after each item
        # has been yielded.
        for count, _ in enumerate(gen, start=1):
            assert count == gen._iterator

    def test_getitem(self, gen):
        """Check the pair returned by ``gen["DEBW107"]`` and the shapes of its two elements."""
        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
        station = gen["DEBW107"]
        assert len(station) == 2
        first, second = station[0], station[1]
        assert first.Stations.data == "DEBW107"
        # NOTE(review): presumably (window, station, variable) after the leading
        # axis; 8 equals window_history + 1 (also asserted below) - confirm
        # the axis meaning against DataGenerator.
        assert first.data.shape[1:] == (8, 1, 2)
        assert second.data.shape[-1] == gen.window_lead_time
        assert first.data.shape[1] == gen.window_history + 1

    def test_threshold_setup(self, gen):
        """Check ``threshold_setup`` against ``np.linspace`` for successive kwargs overrides."""

        def res(arg, val):
            # Each call ADDS a key to gen.kwargs, so overrides accumulate from
            # one check to the next (thr_min stays 10 once it has been set).
            gen.kwargs[arg] = val
            return list(map(float, gen.threshold_setup()))

        # assert_array_almost_equal raises AssertionError on mismatch and returns
        # None otherwise, so it is called directly; wrapping it in
        # ``assert ... is None`` would drop the whole check under ``python -O``.
        compare = np.testing.assert_array_almost_equal
        # The empty-string key exercises the defaults; presumably it is ignored
        # by threshold_setup - verify against the implementation.
        compare(res('', ''), np.linspace(0, 100, 200), decimal=3)
        compare(res('thr_min', 10), np.linspace(10, 100, 200), decimal=3)
        compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3)
        compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3)

    def test_get_key_representation(self, gen):
        """Check ``get_station_key`` for None, index, name, list and invalid keys."""
        gen.stations.append("DEBW108")
        f = gen.get_station_key
        # iter() creates gen._iterator (see test_iter); presumably the None key
        # is resolved through it - confirm against get_station_key.
        iter(gen)
        assert f(None) == "DEBW107"
        with pytest.raises(KeyError) as e:
            f([None, None])
        assert "More than one key was given: [None, None]" in e.value.args[0]
        assert f(1) == "DEBW108"
        assert f([1]) == "DEBW108"
        with pytest.raises(KeyError) as e:
            f(3)
        assert "3 is not in range(0, 2)" in e.value.args[0]
        assert f("DEBW107") == "DEBW107"
        assert f(["DEBW108"]) == "DEBW108"
        with pytest.raises(KeyError) as e:
            f("DEBW999")
        assert "DEBW999 is not in stations" in e.value.args[0]
        with pytest.raises(KeyError) as e:
            f(6.5)
        # This used to be ``assert "<text>"``, which is always true for a
        # non-empty string and therefore checked nothing. Compare against the
        # raised message like the checks above.
        # NOTE(review): only the middle part of the message is matched because
        # the capitalisation of the first word and the rendering of the type
        # are not visible from this file - tighten once confirmed.
        assert "has to be from Union[str, int]. Given was 6.5" in e.value.args[0]