Skip to content
Snippets Groups Projects
Select Git revision
  • e80be31962af928323c01ffec18dc0f4c89d8b91
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
41 results

test_data_generator.py

Blame
  • user avatar
    lukas leufen authored
    e80be319
    History
    test_data_generator.py 3.70 KiB
    import pytest
    import os
    from src.data_generator import DataGenerator
    import logging
    import numpy as np
    import xarray as xr
    import datetime as dt
    import pandas as pd
    from operator import itemgetter
    
    
    class TestDataGenerator:
        """Tests for DataGenerator: construction, repr, length, iteration, item access and key lookup."""

        @pytest.fixture
        def gen(self):
            """Return a fresh DataGenerator for station DEBW107 with variables o3 and temp."""
            return DataGenerator('data', 'UBA', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3')

        def test_init(self, gen):
            """Check the attributes of a newly built generator.

            The first block covers the values handed over by the fixture, the
            second block the attributes the fixture does not pass explicitly.
            """
            assert gen.path == os.path.abspath('data')
            assert gen.network == 'UBA'
            assert gen.stations == ['DEBW107']
            assert gen.variables == ['o3', 'temp']
            assert gen.interpolate_dim == 'datetime'
            assert gen.target_dim == 'variables'
            assert gen.target_var == 'o3'

            assert gen.interpolate_method == "linear"
            assert gen.limit_nan_fill == 1
            assert gen.window_history == 7
            assert gen.window_lead_time == 4
            assert gen.transform_method == "standardise"
            assert gen.kwargs == {}
            assert gen.threshold is not None

        def test_repr(self, gen):
            """Check the string representation, ignoring trailing whitespace."""
            # NOTE(review): the expected path is built relative to this test file while the fixture passes
            # the relative path 'data' - presumably the tests are run from this directory; verify.
            path = os.path.join(os.path.dirname(__file__), 'data')
            expected = (f"DataGenerator(path='{path}', network='UBA', stations=['DEBW107'], "
                        f"variables=['o3', 'temp'], interpolate_dim='datetime', "
                        f"target_dim='variables', target_var='o3', **{{}})")
            assert repr(gen).rstrip() == expected.rstrip()

        def test_len(self, gen):
            """Check that len() follows the number of entries in gen.stations."""
            assert len(gen) == 1
            gen.stations = ['station1', 'station2', 'station3']
            assert len(gen) == 3

        def test_iter(self, gen):
            """Check the state of the generator before and after calling iter()."""
            # NOTE(review): the hasattr checks look up 'iterator' while the final assertion reads
            # '_iterator' - confirm that both attribute names are intended.
            assert hasattr(gen, 'iterator') is False
            iter(gen)
            assert hasattr(gen, 'iterator')
            assert gen._iterator == 0

        def test_next(self, gen):
            """Check that gen._iterator equals the 1-based count of elements yielded so far."""
            gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
            # the yielded element itself is not inspected here, only the counter
            for i, _ in enumerate(gen, start=1):
                assert i == gen._iterator

        def test_getitem(self, gen):
            """Check the two-element result of indexing the generator with a station name."""
            gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
            station = gen["DEBW107"]
            assert len(station) == 2
            assert station[0].Stations.data == "DEBW107"
            # the 8 in the expected shape is window_history + 1, asserted separately below
            assert station[0].data.shape[1:] == (8, 1, 2)
            assert station[1].data.shape[-1] == gen.window_lead_time
            assert station[0].data.shape[1] == gen.window_history + 1

        def test_threshold_setup(self, gen):
            """Check threshold_setup() while thr_min, thr_max and thr_number_of_steps are set one by one.

            The kwargs accumulate across the calls, so each expectation builds on
            the settings of the previous ones.
            """
            def res(arg, val):
                # store the setting on the generator and return the resulting thresholds as floats
                gen.kwargs[arg] = val
                return list(map(float, gen.threshold_setup()))

            # assert_array_almost_equal raises AssertionError on mismatch and returns None otherwise,
            # so it is called directly instead of being wrapped in `assert ... is None`
            compare = np.testing.assert_array_almost_equal
            compare(res('', ''), np.linspace(0, 100, 200), decimal=3)
            compare(res('thr_min', 10), np.linspace(10, 100, 200), decimal=3)
            compare(res('thr_max', 40), np.linspace(10, 40, 200), decimal=3)
            compare(res('thr_number_of_steps', 10), np.linspace(10, 40, 10), decimal=3)

        def test_get_key_representation(self, gen):
            """Check get_station_key() for None, int, str and single-element list keys and its KeyErrors."""
            gen.stations.append("DEBW108")
            f = gen.get_station_key
            iter(gen)

            # None and lists
            assert f(None) == "DEBW107"
            with pytest.raises(KeyError) as e:
                f([None, None])
            assert "More than one key was given: [None, None]" in e.value.args[0]

            # integer keys
            assert f(1) == "DEBW108"
            assert f([1]) == "DEBW108"
            with pytest.raises(KeyError) as e:
                f(3)
            assert "3 is not in range(0, 2)" in e.value.args[0]

            # string keys
            assert f("DEBW107") == "DEBW107"
            assert f(["DEBW108"]) == "DEBW108"
            with pytest.raises(KeyError) as e:
                f("DEBW999")
            assert "DEBW999 is not in stations" in e.value.args[0]

            # unsupported key type
            with pytest.raises(KeyError) as e:
                f(6.5)
            # Compare against the raised message; asserting the bare string literal would always pass.
            # NOTE(review): confirm the exact wording (especially how the type is rendered) against
            # DataGenerator.get_station_key.
            assert "key has to be from Union[str, int]. Given was 6.5 (float)" in e.value.args[0]