Skip to content
Snippets Groups Projects
Select Git revision
  • patch-1
  • master default protected
  • leuschke1-master-patch-24882
  • leuschke1-master-patch-17157
4 results

cairo-1.17.2-GCCcore-8.3.0.eb

Blame
  • test_data_generator.py 3.70 KiB
    import pytest
    import os
    from src.data_generator import DataGenerator
    import logging
    import numpy as np
    import xarray as xr
    import datetime as dt
    import pandas as pd
    from operator import itemgetter
    
    
    class TestDataGenerator:
        """Tests for construction, iteration, indexing and station-key handling of ``DataGenerator``."""

        @pytest.fixture
        def gen(self):
            """Provide a ``DataGenerator`` for station DEBW107 with the variables o3 and temp."""
            return DataGenerator('data', 'UBA', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3')

        def test_init(self, gen):
            """Check the attributes set from the constructor arguments and those not passed explicitly."""
            # Values handed to the constructor in the fixture.
            assert gen.path == os.path.abspath('data')
            assert gen.network == 'UBA'
            assert gen.stations == ['DEBW107']  # a single station name is exposed as a list
            assert gen.variables == ['o3', 'temp']
            assert gen.interpolate_dim == 'datetime'
            assert gen.target_dim == 'variables'
            assert gen.target_var == 'o3'
            # Values the fixture did not pass; the constructor is expected to provide them.
            assert gen.interpolate_method == "linear"
            assert gen.limit_nan_fill == 1
            assert gen.window_history == 7
            assert gen.window_lead_time == 4
            assert gen.transform_method == "standardise"
            assert gen.kwargs == {}
            assert gen.threshold is not None

        def test_repr(self, gen):
            """Check that ``repr`` lists the constructor arguments, ignoring trailing whitespace."""
            # NOTE(review): the expected path is built from this file's directory, whereas test_init
            # compares against os.path.abspath('data'); presumably both only agree when the tests
            # are started from this directory - confirm.
            path = os.path.join(os.path.dirname(__file__), 'data')
            assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "\
                                              f"variables=['o3', 'temp'], interpolate_dim='datetime', " \
                                              f"target_dim='variables', target_var='o3', **{{}})".rstrip()

        def test_len(self, gen):
            """Check that ``len(gen)`` follows the number of entries in ``gen.stations``."""
            assert len(gen) == 1
            gen.stations = ['station1', 'station2', 'station3']
            assert len(gen) == 3

        def test_iter(self, gen):
            """Check the iterator state before and after calling ``iter`` on the generator."""
            # NOTE(review): the hasattr checks use 'iterator' while the value check below reads
            # '_iterator' - confirm which attribute name DataGenerator.__iter__ really sets.
            assert not hasattr(gen, 'iterator')
            iter(gen)
            assert hasattr(gen, 'iterator')
            assert gen._iterator == 0

        def test_next(self, gen):
            """Check that ``gen._iterator`` equals the 1-based count of items yielded so far."""
            gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
            for i, _ in enumerate(gen, start=1):
                assert i == gen._iterator

        def test_getitem(self, gen):
            """Check the pair of objects returned when indexing the generator by station name."""
            gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
            station = gen["DEBW107"]
            assert len(station) == 2
            assert station[0].Stations.data == "DEBW107"
            # 8 matches window_history + 1 (asserted below); presumably the trailing 2 is the
            # number of variables - TODO confirm.
            assert station[0].data.shape[1:] == (8, 1, 2)
            assert station[1].data.shape[-1] == gen.window_lead_time
            assert station[0].data.shape[1] == gen.window_history + 1

        def test_threshold_setup(self, gen):
            """Check ``threshold_setup`` for its initial result and for each ``thr_*`` kwarg."""
            def res(arg, val):
                # Entries accumulate in gen.kwargs, so every call also carries the earlier settings.
                gen.kwargs[arg] = val
                return list(map(float, gen.threshold_setup()))

            # assert_array_almost_equal raises AssertionError on mismatch, so no outer assert is needed.
            compare = np.testing.assert_array_almost_equal
            compare(res('', ''), np.linspace(0, 100, 200), decimal=3)
            compare(res('thr_min', 10), np.linspace(10, 100, 200), decimal=3)
            compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3)
            compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3)

        def test_get_key_representation(self, gen):
            """Check ``get_station_key`` for None, index, name and list keys, and its KeyError messages."""
            gen.stations.append("DEBW108")
            f = gen.get_station_key
            iter(gen)
            # None is resolved to the first station right after iter(gen).
            assert f(None) == "DEBW107"
            with pytest.raises(KeyError) as e:
                f([None, None])
            assert "More than one key was given: [None, None]" in e.value.args[0]
            # Integer keys, bare or wrapped in a one-element list.
            assert f(1) == "DEBW108"
            assert f([1]) == "DEBW108"
            with pytest.raises(KeyError) as e:
                f(3)
            assert "3 is not in range(0, 2)" in e.value.args[0]
            # String keys, bare or wrapped in a one-element list.
            assert f("DEBW107") == "DEBW107"
            assert f(["DEBW108"]) == "DEBW108"
            with pytest.raises(KeyError) as e:
                f("DEBW999")
            assert "DEBW999 is not in stations" in e.value.args[0]
            with pytest.raises(KeyError) as e:
                f(6.5)
            # Compare against the raised message; asserting the bare string literal would always pass.
            # NOTE(review): confirm the exact wording against DataGenerator.get_station_key.
            assert "key has to be from Union[str, int]. Given was 6.5 (float)" in e.value.args[0]