"""Unit tests for ``src.data_handling.data_distributor.Distributor``."""

import math
import os

import keras
import numpy as np
import pytest

from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator
from test.test_modules.test_training import my_test_model

# Directory holding the test station data used by the generator fixtures.
_DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')


class TestDistributor:
    """Tests for batch distribution, length computation and data permutation."""

    @pytest.fixture
    def generator(self):
        """Single-station generator (DEBW107) with o3/temp inputs and o3 target."""
        return DataGenerator(_DATA_PATH, 'AIRBASE', 'DEBW107', ['o3', 'temp'], 'datetime', 'variables', 'o3',
                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def generator_two_stations(self):
        """Two-station generator (DEBW107, DEBW013) with the same variables as ``generator``."""
        return DataGenerator(_DATA_PATH, 'AIRBASE', ['DEBW107', 'DEBW013'], ['o3', 'temp'], 'datetime',
                             'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def model(self):
        """Test model built without the minor branch."""
        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)

    @pytest.fixture
    def model_with_minor_branch(self):
        """Test model built with the minor branch enabled."""
        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)

    @pytest.fixture
    def distributor(self, generator, model):
        """Distributor over the single-station generator with default settings."""
        return Distributor(generator, model)

    def test_init_defaults(self, distributor):
        """Default batch size is 256, fit_call is enabled and permutation is disabled."""
        assert distributor.batch_size == 256
        assert distributor.fit_call is True
        assert distributor.do_data_permutation is False

    def test_get_model_rank(self, distributor, model_with_minor_branch):
        """The model rank is 1 without and 2 with the minor branch."""
        assert distributor._get_model_rank() == 1
        distributor.model = model_with_minor_branch
        assert distributor._get_model_rank() == 2

    def test_get_number_of_mini_batches(self, distributor):
        """The number of mini batches is the ceiling of samples / batch_size (axis 1 holds samples here)."""
        values = np.zeros((2, 2311, 19))
        assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)

    def test_distribute_on_batches_single_loop(self, generator_two_stations, model):
        """A single pass (fit_call=False) never yields a batch larger than batch_size."""
        d = Distributor(generator_two_stations, model)
        for e in d.distribute_on_batches(fit_call=False):
            assert e[0].shape[0] <= d.batch_size

    def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model):
        """With the default fit_call the batches repeat identically on every pass."""
        d = Distributor(generator_two_stations, model)
        n_batches = len(d)
        # Guard against a vacuous test: without batches nothing would be compared.
        assert n_batches > 0
        elements = []
        last_index = -1
        for i, e in enumerate(d.distribute_on_batches()):
            last_index = i
            if i < n_batches:
                # First pass: remember every batch.
                elements.append(e[0])
            elif i < 2 * n_batches:
                # Second pass: each batch must repeat its first-pass counterpart.
                np.testing.assert_array_equal(e[0], elements[i - n_batches])
            else:
                # The third pass starts. Stop here, because the generator is meant to loop forever.
                break
        # The generator must have kept yielding until the third pass began.
        assert last_index == 2 * n_batches

    def test_len(self, distributor):
        """The length of a single station equals ceil(samples / batch_size)."""
        expected = math.ceil(len(distributor.generator[0][0]) / distributor.batch_size)
        assert len(distributor) == expected

    def test_len_two_stations(self, generator_two_stations, model):
        """The length over several stations is the sum of the per-station batch counts."""
        gen = generator_two_stations
        d = Distributor(gen, model)
        expected = sum(math.ceil(len(gen[station][0]) / d.batch_size) for station in (0, 1))
        assert len(d) == expected

    def test_permute_data_no_permutation(self, distributor):
        """With permutation disabled, x and y are returned unchanged."""
        x = np.array(range(20)).reshape(2, 10).T
        y = np.array(range(10)).reshape(10, 1)
        x_perm, y_perm = distributor._permute_data(x, y)
        np.testing.assert_equal(x, x_perm)
        np.testing.assert_equal(y, y_perm)

    def test_permute_data(self, distributor):
        """With permutation enabled, rows are shuffled jointly so x/y pairs stay aligned."""
        x = np.array(range(20)).reshape(2, 10).T
        y = np.array(range(10)).reshape(10, 1)
        distributor.do_data_permutation = True
        x_perm, y_perm = distributor._permute_data(x, y)
        # Every row keeps its pairing: column 0 equals y and column 1 equals y + 10.
        np.testing.assert_array_equal(x_perm[:, 0], y_perm[:, 0])
        np.testing.assert_array_equal(x_perm[:, 1], y_perm[:, 0] + 10)
        # Sorting each column restores the originals, so no value was lost or duplicated.
        np.testing.assert_equal(x, np.sort(x_perm, axis=0))
        np.testing.assert_equal(y, np.sort(y_perm, axis=0))