Select Git revision
cairo-1.17.2-GCCcore-8.3.0.eb
test_data_generator.py 3.70 KiB
import pytest
import os
from src.data_generator import DataGenerator
import logging
import numpy as np
import xarray as xr
import datetime as dt
import pandas as pd
from operator import itemgetter
class TestDataGenerator:
    """Tests for ``DataGenerator`` built for station DEBW107 (network UBA) with variables o3 and temp."""

    # Resolve the data directory relative to this test file rather than the current working
    # directory, so the fixture, ``test_init`` and ``test_repr`` all agree on the same location
    # regardless of where pytest is launched from.
    data_path = os.path.join(os.path.dirname(__file__), 'data')

    @pytest.fixture
    def gen(self):
        """Return a fresh DataGenerator for DEBW107 with variables o3/temp and target variable o3."""
        return DataGenerator(self.data_path, 'UBA', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3')

    def test_init(self, gen):
        """The constructor stores its arguments and applies the expected defaults."""
        assert gen.path == os.path.abspath(self.data_path)
        assert gen.network == 'UBA'
        assert gen.stations == ['DEBW107']
        assert gen.variables == ['o3', 'temp']
        assert gen.interpolate_dim == 'datetime'
        assert gen.target_dim == 'variables'
        assert gen.target_var == 'o3'
        assert gen.interpolate_method == "linear"
        assert gen.limit_nan_fill == 1
        assert gen.window_history == 7
        assert gen.window_lead_time == 4
        assert gen.transform_method == "standardise"
        assert gen.kwargs == {}
        assert gen.threshold is not None

    def test_repr(self, gen):
        """repr() lists the constructor arguments, including the absolute data path."""
        path = os.path.abspath(self.data_path)
        expected = (f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "
                    f"variables=['o3', 'temp'], interpolate_dim='datetime', "
                    f"target_dim='variables', target_var='o3', **{{}})")
        assert repr(gen).rstrip() == expected.rstrip()

    def test_len(self, gen):
        """len() reports the number of configured stations."""
        assert len(gen) == 1
        gen.stations = ['station1', 'station2', 'station3']
        assert len(gen) == 3

    def test_iter(self, gen):
        """iter() creates the private ``_iterator`` counter and sets it to 0."""
        # Check the same attribute that is read below; the counter must not exist before iteration.
        assert not hasattr(gen, '_iterator')
        iter(gen)
        assert hasattr(gen, '_iterator')
        assert gen._iterator == 0

    def test_next(self, gen):
        """Each step of iteration advances ``_iterator`` by one, starting at 1."""
        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
        for i, _ in enumerate(gen, start=1):
            assert i == gen._iterator

    def test_getitem(self, gen):
        """Indexing by station name returns an (inputs, labels) pair with the expected shapes."""
        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
        station = gen["DEBW107"]
        assert len(station) == 2
        assert station[0].Stations.data == "DEBW107"
        assert station[0].data.shape[1:] == (8, 1, 2)
        assert station[1].data.shape[-1] == gen.window_lead_time
        assert station[0].data.shape[1] == gen.window_history + 1

    def test_threshold_setup(self, gen):
        """threshold_setup() honours thr_min, thr_max and thr_number_of_steps from kwargs."""
        def res(**kwargs):
            # Settings accumulate on purpose: each call builds on the previous ones.
            gen.kwargs.update(kwargs)
            return list(map(float, gen.threshold_setup()))

        # assert_array_almost_equal raises AssertionError on mismatch, so no extra ``assert`` is needed.
        compare = np.testing.assert_array_almost_equal
        compare(res(), np.linspace(0, 100, 200), decimal=3)
        compare(res(thr_min=10), np.linspace(10, 100, 200), decimal=3)
        compare(res(thr_max=40), np.linspace(10, 40, 200), decimal=3)
        compare(res(thr_number_of_steps=10), np.linspace(10, 40, 10), decimal=3)

    def test_get_key_representation(self, gen):
        """get_station_key() resolves None, int, str and single-element lists; rejects everything else."""
        gen.stations.append("DEBW108")
        f = gen.get_station_key
        iter(gen)
        assert f(None) == "DEBW107"
        with pytest.raises(KeyError) as e:
            f([None, None])
        assert "More than one key was given: [None, None]" in e.value.args[0]
        assert f(1) == "DEBW108"
        assert f([1]) == "DEBW108"
        with pytest.raises(KeyError) as e:
            f(3)
        assert "3 is not in range(0, 2)" in e.value.args[0]
        assert f("DEBW107") == "DEBW107"
        assert f(["DEBW108"]) == "DEBW108"
        with pytest.raises(KeyError) as e:
            f("DEBW999")
        assert "DEBW999 is not in stations" in e.value.args[0]
        with pytest.raises(KeyError) as e:
            f(6.5)
        # Must compare against the raised message; a bare string literal would always be truthy.
        assert "key has to be from Union[str, int]. Given was 6.5 (float)" in e.value.args[0]