# Select Git revision
# test_data_distributor.py 2.96 KiB
import math
import os
import shutil
import keras
import numpy as np
import pytest
from src.data_handling.data_distributor import Distributor
from src.data_handling.data_generator import DataGenerator
from test.test_modules.test_training import my_test_model
class TestDistributor:
    """Tests for ``Distributor`` using ``DataGenerator`` fixtures and models from ``my_test_model``."""

    @pytest.fixture
    def generator(self):
        """Return a ``DataGenerator`` on the local ``data`` directory for the single station ``DEBW107``."""
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                             'datetime', 'variables', 'o3', statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def generator_two_stations(self):
        """Return a ``DataGenerator`` on the local ``data`` directory for ``DEBW107`` and ``DEBW013``."""
        return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', ['DEBW107', 'DEBW013'],
                             ['o3', 'temp'], 'datetime', 'variables', 'o3',
                             statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})

    @pytest.fixture
    def model(self):
        """Return the test model built with the last ``my_test_model`` argument set to ``False``."""
        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)

    @pytest.fixture
    def model_with_minor_branch(self):
        """Return the test model built with the last ``my_test_model`` argument set to ``True``."""
        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)

    @pytest.fixture
    def distributor(self, generator, model):
        """Return a ``Distributor`` created from the single-station generator and the plain model."""
        return Distributor(generator, model)

    def test_init_defaults(self, distributor):
        """A distributor created without extra arguments has batch size 256 and ``fit_call`` enabled."""
        assert distributor.batch_size == 256
        assert distributor.fit_call is True

    def test_get_model_rank(self, distributor, model_with_minor_branch):
        """The model rank is 1 for the plain model and 2 for the model with minor branch."""
        assert distributor._get_model_rank() == 1
        distributor.model = model_with_minor_branch
        assert distributor._get_model_rank() == 2
        # NOTE(review): no assertion follows this assignment, so it currently verifies nothing.
        # Presumably a check of `_get_model_rank` on an invalid model was intended - TODO confirm.
        distributor.model = 1

    def test_get_number_of_mini_batches(self, distributor):
        """An array of shape (2, 2311, 19) is split into ``ceil(2311 / batch_size)`` mini batches."""
        values = np.zeros((2, 2311, 19))
        assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)

    def test_distribute_on_batches_single_loop(self, generator_two_stations, model):
        """With ``fit_call=False`` the iteration ends and no batch is longer than ``batch_size``."""
        d = Distributor(generator_two_stations, model)
        for e in d.distribute_on_batches(fit_call=False):
            assert e[0].shape[0] <= d.batch_size

    def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model):
        """In the default mode the second pass yields the same first elements as the first pass."""
        d = Distributor(generator_two_stations, model)
        n_batches = len(d)  # loop invariant, evaluate once
        first_pass = []
        compared = 0
        for i, e in enumerate(d.distribute_on_batches()):
            if i < n_batches:
                # first pass: remember the first element of every yielded batch
                first_pass.append(e[0])
            elif i < 2 * n_batches:
                # second pass: each batch must repeat the batch at the same position of the first pass;
                # assert_array_equal raises AssertionError on mismatch
                np.testing.assert_array_equal(e[0], first_pass[i - n_batches])
                compared += 1
            else:  # break when 3rd iteration starts (is called as infinite loop)
                break
        # make sure the repetition check really ran for every batch of the second pass
        assert compared == n_batches

    def test_len(self, distributor):
        """``len`` of the distributor equals the mini batch count of the first generator item."""
        expected = math.ceil(len(distributor.generator[0][0]) / distributor.batch_size)
        assert len(distributor) == expected

    def test_len_two_stations(self, generator_two_stations, model):
        """``len`` of the distributor is the sum of the mini batch counts of both generator items."""
        gen = generator_two_stations
        d = Distributor(gen, model)
        expected = sum(math.ceil(len(gen[idx][0]) / d.batch_size) for idx in (0, 1))
        assert len(d) == expected