From 0f569572ae50d452c9048b0e5c7df547dd9b7ded Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Fri, 6 Dec 2019 12:30:00 +0100 Subject: [PATCH] split a test into 2, removed unnecessary if statement in len method --- src/data_handling/data_distributor.py | 11 ++++------- test/test_data_handling/test_data_distributor.py | 13 ++++++++++--- test/test_modules/test_pre_processing.py | 10 +++++----- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index 33550cbf..0ca04bde 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -21,7 +21,7 @@ class Distributor(keras.utils.Sequence): elif isinstance(mod_out, list): # multiple output branches, e.g.: [(None, ahead), (None, ahead)] mod_rank = len(mod_out) - else: # pragma: no branch + else: # pragma: no cover raise TypeError("model output shape must either be tuple or list.") return mod_rank @@ -46,10 +46,7 @@ class Distributor(keras.utils.Sequence): raise StopIteration def __len__(self): - if self.batch_size > 1: - num_batch = 0 - for _ in self.distribute_on_batches(fit_call=False): - num_batch += 1 - else: - num_batch = len(self.generator) + num_batch = 0 + for _ in self.distribute_on_batches(fit_call=False): + num_batch += 1 return num_batch diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index 894c086e..fc6ce9e9 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -27,6 +27,10 @@ class TestDistributor: def model(self): return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False) + @pytest.fixture + def model_with_minor_branch(self): + return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True) + @pytest.fixture def distributor(self, generator, model): return Distributor(generator, model) @@ -35,9 +39,9 @@ class TestDistributor: assert distributor.batch_size == 256 assert distributor.fit_call is True - def test_get_model_rank(self, distributor): + def test_get_model_rank(self, distributor, model_with_minor_branch): assert distributor._get_model_rank() == 1 - distributor.model = my_test_model(keras.layers.PReLU, 5, 3, 0.1, True) + distributor.model = model_with_minor_branch assert distributor._get_model_rank() == 2 distributor.model = 1 @@ -45,10 +49,13 @@ class TestDistributor: values = np.zeros((2, 2311, 19)) assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size) - def test_distribute_on_batches(self, generator_two_stations, model): + def test_distribute_on_batches_single_loop(self, generator_two_stations, model): d = Distributor(generator_two_stations, model) for e in d.distribute_on_batches(fit_call=False): assert e[0].shape[0] <= d.batch_size + + def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model): + d = Distributor(generator_two_stations, model) elements = [] for i, e in enumerate(d.distribute_on_batches()): if i < len(d): diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index ca5502e2..7f4ed517 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -41,11 +41,11 @@ class TestPreProcessing: ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) caplog.set_level(logging.INFO) - PreProcessing() - assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') - assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). Found ' - r'5/5 valid stations.')) + with PreProcessing(): + assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') + assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') + assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). ' + r'Found 5/5 valid stations.')) RunEnvironment().__del__() def test_run(self, obj_with_exp_setup): -- GitLab