Skip to content
Snippets Groups Projects
Select Git revision
  • df43d9db86020cd23ad6a568312fd5fa814bf85c
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
41 results

mlt_modules_hdfml.sh

Blame
  • test_data_distributor.py 2.96 KiB
    import math
    import os
    import shutil
    
    import keras
    import numpy as np
    import pytest
    
    from src.data_handling.data_distributor import Distributor
    from src.data_handling.data_generator import DataGenerator
    from test.test_modules.test_training import my_test_model
    
    
    class TestDistributor:
        """Tests for ``Distributor``, which hands the data of a ``DataGenerator`` to a keras model in mini-batches."""

        @pytest.fixture
        def generator(self):
            """Single-station generator (DEBW107, variables o3 and temp) reading from the local test ``data`` dir."""
            return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                                 'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

        @pytest.fixture
        def generator_two_stations(self):
            """Two-station generator (DEBW107 and DEBW013) with the same variables and statistics as ``generator``."""
            return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'],
                                 ['o3', 'temp'], 'datetime', 'variables', 'o3',
                                 statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

        @pytest.fixture
        def model(self):
            """Test model built by ``my_test_model`` without the minor branch (last argument ``False``)."""
            return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)

        @pytest.fixture
        def model_with_minor_branch(self):
            """Test model built by ``my_test_model`` with the minor branch enabled (last argument ``True``)."""
            return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)

        @pytest.fixture
        def distributor(self, generator, model):
            """Distributor over the single-station generator and the single-branch model, using default settings."""
            return Distributor(generator, model)

        def test_init_defaults(self, distributor):
            """Without extra arguments the distributor uses a batch size of 256 and ``fit_call=True``."""
            assert distributor.batch_size == 256
            assert distributor.fit_call is True

        def test_get_model_rank(self, distributor, model_with_minor_branch):
            """The model rank is 1 without the minor branch and 2 once a model with the minor branch is assigned."""
            assert distributor._get_model_rank() == 1
            distributor.model = model_with_minor_branch
            assert distributor._get_model_rank() == 2

        def test_get_number_of_mini_batches(self, distributor):
            """The mini-batch count is the sample count along axis 1 divided by ``batch_size``, rounded up."""
            values = np.zeros((2, 2311, 19))
            assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)

        def test_distribute_on_batches_single_loop(self, generator_two_stations, model):
            """With ``fit_call=False`` the distribution terminates and no batch exceeds ``batch_size`` samples."""
            d = Distributor(generator_two_stations, model)
            batch_count = 0
            for batch in d.distribute_on_batches(fit_call=False):
                assert batch[0].shape[0] <= d.batch_size
                batch_count += 1
            # guard against a vacuous pass: the size check above must have run at least once
            assert batch_count > 0

        def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model):
            """With the default ``fit_call`` the distribution loops endlessly, repeating identical batches per pass.

            The first pass (``len(d)`` batches) is recorded, every batch of the second pass is compared with its
            counterpart from the first pass, and the loop is left as soon as the third pass starts.
            """
            d = Distributor(generator_two_stations, model)
            n_batches = len(d)
            first_pass = []
            for i, batch in enumerate(d.distribute_on_batches()):
                if i < n_batches:
                    first_pass.append(batch[0])
                elif i < 2 * n_batches:
                    # second pass: each batch must repeat the batch at the same position of the first pass
                    np.testing.assert_array_equal(batch[0], first_pass[i - n_batches])
                else:  # 3rd pass has started, i.e. the generator really loops; stop consuming it
                    break
            else:
                # only reached if the generator ran dry before the third pass
                pytest.fail("distribute_on_batches() stopped before a third pass; expected an infinite loop")

        def test_len(self, distributor):
            """``len`` equals the number of mini-batches needed for the samples in ``generator[0][0]``."""
            n_samples = len(distributor.generator[0][0])
            assert len(distributor) == math.ceil(n_samples / distributor.batch_size)

        def test_len_two_stations(self, generator_two_stations, model):
            """With two stations, ``len`` is the sum of the per-station mini-batch counts."""
            gen = generator_two_stations
            d = Distributor(gen, model)
            expected = math.ceil(len(gen[0][0]) / d.batch_size) + math.ceil(len(gen[1][0]) / d.batch_size)
            assert len(d) == expected