diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index 33550cbf5029320e2f9fe26687eb5022258cf516..0ca04bdefb2c1fe6085b0471a8df5c58cbe5ac19 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -21,7 +21,7 @@ class Distributor(keras.utils.Sequence):
         elif isinstance(mod_out, list):
             # multiple output branches, e.g.: [(None, ahead), (None, ahead)]
             mod_rank = len(mod_out)
-        else:  # pragma: no branch
+        else:  # pragma: no cover
             raise TypeError("model output shape must either be tuple or list.")
         return mod_rank
 
@@ -46,10 +46,7 @@ class Distributor(keras.utils.Sequence):
                             raise StopIteration
 
     def __len__(self):
-        if self.batch_size > 1:
-            num_batch = 0
-            for _ in self.distribute_on_batches(fit_call=False):
-                num_batch += 1
-        else:
-            num_batch = len(self.generator)
+        num_batch = 0
+        for _ in self.distribute_on_batches(fit_call=False):
+            num_batch += 1
         return num_batch
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index 894c086eccb0683cb4480be761503b242bb788bb..fc6ce9e9ff9da41eb8caf64f059226903e9d020c 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -27,6 +27,10 @@ class TestDistributor:
     def model(self):
         return my_test_model(keras.layers.PReLU, 5, 3, 0.1, False)
 
+    @pytest.fixture
+    def model_with_minor_branch(self):
+        return my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)
+
     @pytest.fixture
     def distributor(self, generator, model):
         return Distributor(generator, model)
@@ -35,9 +39,9 @@ class TestDistributor:
         assert distributor.batch_size == 256
         assert distributor.fit_call is True
 
-    def test_get_model_rank(self, distributor):
+    def test_get_model_rank(self, distributor, model_with_minor_branch):
         assert distributor._get_model_rank() == 1
-        distributor.model = my_test_model(keras.layers.PReLU, 5, 3, 0.1, True)
+        distributor.model = model_with_minor_branch
         assert distributor._get_model_rank() == 2
         distributor.model = 1
 
@@ -45,10 +49,13 @@ class TestDistributor:
         values = np.zeros((2, 2311, 19))
         assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size)
 
-    def test_distribute_on_batches(self,  generator_two_stations, model):
+    def test_distribute_on_batches_single_loop(self,  generator_two_stations, model):
         d = Distributor(generator_two_stations, model)
         for e in d.distribute_on_batches(fit_call=False):
             assert e[0].shape[0] <= d.batch_size
+
+    def test_distribute_on_batches_infinite_loop(self, generator_two_stations, model):
+        d = Distributor(generator_two_stations, model)
         elements = []
         for i, e in enumerate(d.distribute_on_batches()):
             if i < len(d):
diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py
index ca5502e2bf579bd2bf0082433783ec15b72baf4c..7f4ed517f83ce07b2d5e82a05f63e9f4c60375fd 100644
--- a/test/test_modules/test_pre_processing.py
+++ b/test/test_modules/test_pre_processing.py
@@ -41,11 +41,11 @@ class TestPreProcessing:
         ExperimentSetup(parser_args={}, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087'],
                         var_all_dict={'o3': 'dma8eu', 'temp': 'maximum'})
         caplog.set_level(logging.INFO)
-        PreProcessing()
-        assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
-        assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
-        assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). Found '
-                                                                    r'5/5 valid stations.'))
+        with PreProcessing():
+            assert caplog.record_tuples[0] == ('root', 20, 'PreProcessing started')
+            assert caplog.record_tuples[1] == ('root', 20, 'check valid stations started')
+            assert caplog.record_tuples[-1] == ('root', 20, PyTestRegex(r'run for \d+\.\d+s to check 5 station\(s\). '
+                                                                        r'Found 5/5 valid stations.'))
         RunEnvironment().__del__()
 
     def test_run(self, obj_with_exp_setup):