Skip to content
Snippets Groups Projects
Commit 0f569572 authored by lukas leufen's avatar lukas leufen
Browse files

split a test into 2, removed unnecessary if statement in len method

parent 8fda782b
Branches
Tags
2 merge requests: !24 "include recent development", !20 "not distributed training"
Pipeline #27051 passed
......@@ -21,7 +21,7 @@ class Distributor(keras.utils.Sequence):
elif isinstance(mod_out, list):
# multiple output branches, e.g.: [(None, ahead), (None, ahead)]
mod_rank = len(mod_out)
else: # pragma: no branch
else: # pragma: no cover
raise TypeError("model output shape must either be tuple or list.")
return mod_rank
......@@ -46,10 +46,7 @@ class Distributor(keras.utils.Sequence):
raise StopIteration
def __len__(self):
if self.batch_size > 1:
num_batch = 0
for _ in self.distribute_on_batches(fit_call=False):
num_batch += 1
else:
num_batch = len(self.generator)
return num_batch
......@@ -27,6 +27,10 @@ class TestDistributor:
def model(self):
return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)
@pytest.fixture
def model_with_minor_branch(self):
return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)
@pytest.fixture
def distributor(self, generator, model):
return Distributor(generator, model)
......@@ -35,9 +39,9 @@ class TestDistributor:
assert distributor.batch_size == 256
assert distributor.fit_call is True
def test_get_model_rank(self, distributor):
def test_get_model_rank(self, distributor, model_with_minor_branch):
assert distributor._get_model_rank() == 1
distributor.model = my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)
distributor.model = model_with_minor_branch
assert distributor._get_model_rank() == 2
distributor.model = 1
......@@ -45,10 +49,13 @@ class TestDistributor:
values = np.zeros((2, 2311, 19))
assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)
def test_distribute_on_batches(self, generator_two_stations, model):
def test_distribute_on_batches_single_loop(self, generator_two_stations, model):
d = Distributor(generator_two_stations, model)
for e in d.distribute_on_batches(fit_call=False):
assert e[0].shape[0] <= d.batch_size
def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model):
d = Distributor(generator_two_stations, model)
elements = []
for i, e in enumerate(d.distribute_on_batches()):
if i < len(d):
......
......@@ -41,11 +41,11 @@ class TestPreProcessing:
ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
caplog.set_level(logging.INFO)
PreProcessing()
with PreProcessing():
assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). Found '
r'5/5 valid stations.'))
assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). '
r'Found 5/5 valid stations.'))
RunEnvironment().__del__()
def test_run(self, obj_with_exp_setup):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment