Skip to content
Snippets Groups Projects
Select Git revision
  • 88cc2f6030921b691c4026d2f4d9455ee9ce897c
  • 2022 default
  • 2021
  • master protected
  • 2021
5 results

README.md

Blame
  • test_data_generator.py 3.19 KiB
    import pytest
    import os
    from src.data_generator import DataGenerator
    import logging
    import numpy as np
    import xarray as xr
    import datetime as dt
    import pandas as pd
    from operator import itemgetter
    
    
    class TestDataGenerator:
        """Tests for DataGenerator: construction, repr, len/iter/next, item access and station key lookup."""

        # Directory handed to the generator as data path; computed once and reused by the tests below.
        data_dir = os.path.join(os.path.dirname(__file__), 'data')

        @pytest.fixture
        def gen(self):
            """Return a new DataGenerator for network 'UBA' and station 'DEBW107' with variables o3 and temp."""
            return DataGenerator(self.data_dir, 'UBA', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3')

        def test_init(self, gen):
            """Constructor arguments are stored on the instance; all other attributes hold their defaults."""
            # values passed explicitly by the fixture (the single station string is expected as a list)
            assert gen.data_path == self.data_dir
            assert gen.network == 'UBA'
            assert gen.stations == ['DEBW107']
            assert gen.variables == ['o3', 'temp']
            assert gen.interpolate_dim == 'datetime'
            assert gen.target_dim == 'variables'
            assert gen.target_var == 'o3'
            # values not passed by the fixture
            assert gen.interpolate_method == "linear"
            assert gen.limit_nan_fill == 1
            assert gen.window_history == 7
            assert gen.window_lead_time == 4
            assert gen.transform_method == "standardise"
            assert gen.kwargs == {}

        def test_repr(self, gen):
            """repr(gen) lists path, network, stations, variables, dims, target variable and the empty kwargs."""
            expected = (f"DataGenerator(path='{self.data_dir}', network='UBA', stations=['DEBW107'], "
                        f"variables=['o3', 'temp'], interpolate_dim='datetime', "
                        f"target_dim='variables', target_var='o3', **{{}})")
            # trailing whitespace of the actual representation is ignored
            assert repr(gen).rstrip() == expected

        def test_len(self, gen):
            """len(gen) equals the number of entries in gen.stations."""
            assert len(gen) == 1
            gen.stations = ['station1', 'station2', 'station3']
            assert len(gen) == 3

        def test_iter(self, gen):
            """iter(gen) creates the `_iterator` attribute and sets it to 0."""
            assert not hasattr(gen, '_iterator')
            iter(gen)
            assert hasattr(gen, '_iterator')
            assert gen._iterator == 0

        def test_next(self, gen):
            """While looping over gen, `_iterator` equals the 1-based count of items received so far."""
            gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
            for count, _ in enumerate(gen, start=1):
                assert count == gen._iterator

        def test_getitem(self, gen):
            """gen[station_id] returns two objects whose shapes follow window_history and window_lead_time."""
            gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
            station = gen["DEBW107"]
            assert len(station) == 2
            first, second = station[0], station[1]
            assert first.Stations.data == "DEBW107"
            assert first.data.shape[1:] == (8, 1, 2)
            assert second.data.shape[-1] == gen.window_lead_time
            assert first.data.shape[1] == gen.window_history + 1

        def test_get_station_key(self, gen):
            """get_station_key resolves None, int and str keys (bare or in a one-element list), else KeyError."""
            gen.stations.append("DEBW108")
            get_key = gen.get_station_key
            iter(gen)

            # None resolves to the first station right after iter(gen)
            assert get_key(None) == "DEBW107"
            with pytest.raises(KeyError) as excinfo:
                get_key([None, None])
            assert "More than one key was given: [None, None]" in excinfo.value.args[0]

            # integer keys index into gen.stations
            assert get_key(1) == "DEBW108"
            assert get_key([1]) == "DEBW108"
            with pytest.raises(KeyError) as excinfo:
                get_key(3)
            assert "3 is not in range(0, 2)" in excinfo.value.args[0]

            # string keys have to be contained in gen.stations
            assert get_key("DEBW107") == "DEBW107"
            assert get_key(["DEBW108"]) == "DEBW108"
            with pytest.raises(KeyError) as excinfo:
                get_key("DEBW999")
            assert "DEBW999 is not in stations" in excinfo.value.args[0]

            # any other key type is rejected
            with pytest.raises(KeyError) as excinfo:
                get_key(6.5)
            # NOTE(review): this check used to assert a bare string literal, which is always truthy and therefore
            # verified nothing. The expected text could not be checked against DataGenerator.get_station_key from
            # this file - align it with the implementation if the wording differs.
            assert "key has to be from Union[str, int]. Given was 6.5 (float)" in excinfo.value.args[0]