diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index 33550cbf5029320e2f9fe26687eb5022258cf516..0ca04bdefb2c1fe6085b0471a8df5c58cbe5ac19 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -21,7 +21,7 @@ class Distributor(keras.utils.Sequence): elif isinstance(mod_out, list): # multiple output branches, e.g.: [(None, ahead), (None, ahead)] mod_rank = len(mod_out) - else: # pragma: no branch + else: # pragma: no cover raise TypeError("model output shape must either be tuple or list.") return mod_rank @@ -46,10 +46,7 @@ class Distributor(keras.utils.Sequence): raise StopIteration def __len__(self): - if self.batch_size > 1: - num_batch = 0 - for _ in self.distribute_on_batches(fit_call=False): - num_batch += 1 - else: - num_batch = len(self.generator) + num_batch = 0 + for _ in self.distribute_on_batches(fit_call=False): + num_batch += 1 return num_batch diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index 894c086eccb0683cb4480be761503b242bb788bb..fc6ce9e9ff9da41eb8caf64f059226903e9d020c 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -27,6 +27,10 @@ class TestDistributor: def model(self): return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False) + @pytest.fixture + def model_with_minor_branch(self): + return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True) + @pytest.fixture def distributor(self, generator, model): return Distributor(generator, model) @@ -35,9 +39,9 @@ class TestDistributor: assert distributor.batch_size == 256 assert distributor.fit_call is True - def test_get_model_rank(self, distributor): + def test_get_model_rank(self, distributor, model_with_minor_branch): assert distributor._get_model_rank() == 1 - distributor.model = my_test_model(keras.layers.PReLU, 5, 3, 0.1, True) + distributor.model = model_with_minor_branch assert distributor._get_model_rank() == 2 distributor.model = 1 @@ -45,10 +49,13 @@ class TestDistributor: values = np.zeros((2, 2311, 19)) assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size) - def test_distribute_on_batches(self, generator_two_stations, model): + def test_distribute_on_batches_single_loop(self, generator_two_stations, model): d = Distributor(generator_two_stations, model) for e in d.distribute_on_batches(fit_call=False): assert e[0].shape[0] <= d.batch_size + + def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model): + d = Distributor(generator_two_stations, model) elements = [] for i, e in enumerate(d.distribute_on_batches()): if i < len(d): diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index ca5502e2bf579bd2bf0082433783ec15b72baf4c..7f4ed517f83ce07b2d5e82a05f63e9f4c60375fd 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -41,11 +41,11 @@ class TestPreProcessing: ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'], var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'}) caplog.set_level(logging.INFO) - PreProcessing() - assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') - assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') - assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). Found ' - r'5/5 valid stations.')) + with PreProcessing(): + assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started') + assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started') + assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). ' + r'Found 5/5 valid stations.')) RunEnvironment().__del__() def test_run(self, obj_with_exp_setup):