diff --git a/mlair/data_handler/abstract_data_handler.py b/mlair/data_handler/abstract_data_handler.py
index 9ea163fcad2890580e9c44e4bda0627d6419dc9f..a82e5005e8b30f9e3978ae61859e6b80746d95f1 100644
--- a/mlair/data_handler/abstract_data_handler.py
+++ b/mlair/data_handler/abstract_data_handler.py
@@ -22,6 +22,9 @@ class AbstractDataHandler(object):
         """Return initialised class."""
         return cls(*args, **kwargs)
 
+    def __len__(self, upsampling=False):
+        raise NotImplementedError
+
     @classmethod
     def requirements(cls, skip_args=None):
         """Return requirements and own arguments without duplicates."""
diff --git a/mlair/data_handler/default_data_handler.py b/mlair/data_handler/default_data_handler.py
index 06fd5772db2e18ff70792e8126da83ccfac46f82..69c9537b10ca583adf84480636680a99ab265a67 100644
--- a/mlair/data_handler/default_data_handler.py
+++ b/mlair/data_handler/default_data_handler.py
@@ -55,6 +55,8 @@ class DefaultDataHandler(AbstractDataHandler):
         self._X_extreme = None
         self._Y_extreme = None
         self._data_intersection = None
+        self._len = None
+        self._len_upsampling = None
         self._use_multiprocessing = use_multiprocessing
         self._max_number_multiprocessing = max_number_multiprocessing
         _name_affix = str(f"{str(self.id_class)}_{name_affix}" if name_affix is not None else id(self))
@@ -134,9 +136,11 @@ class DefaultDataHandler(AbstractDataHandler):
     def __repr__(self):
         return str(self._collection[0])
 
-    def __len__(self):
-        if self._data_intersection is not None:
-            return len(self._data_intersection)
+    def __len__(self, upsampling=False):
+        if upsampling is False:
+            return self._len
+        else:
+            return self._len_upsampling
 
     def get_X_original(self):
         X = []
@@ -178,6 +182,7 @@ class DefaultDataHandler(AbstractDataHandler):
             Y = Y_original.sel({dim: intersect})
         self._data_intersection = intersect
         self._X, self._Y = X, Y
+        self._len = len(self._data_intersection)
 
     def get_observation(self):
         dim = self.time_dim
@@ -212,6 +217,7 @@ class DefaultDataHandler(AbstractDataHandler):
         if extreme_values is None:
             logging.debug(f"No extreme values given, skip multiply extremes")
             self._X_extreme, self._Y_extreme = self._X, self._Y
+            self._len_upsampling = self._len
             return
 
         # check type of inputs
@@ -247,6 +253,7 @@ class DefaultDataHandler(AbstractDataHandler):
 
             self._Y_extreme = xr.concat([Y, extremes_Y], dim=dim)
             self._X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
+        self._len_upsampling = len(self._X_extreme[0].coords[dim])
 
     @staticmethod
     def _add_timedelta(data, dim, timedelta):
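
Caching _len and _len_upsampling means the length no longer has to be recomputed from the intersection index on every query. The upsampled length is read from the time coordinate of the first input after the duplicated extremes were appended; the same xarray pattern in isolation (dimension name and array shapes are made up for the example):

import numpy as np
import xarray as xr

dim = "datetime"
X = [xr.DataArray(np.zeros((5, 2)), dims=(dim, "variables"), coords={dim: np.arange(5)})]
extremes_X = [xr.DataArray(np.zeros((3, 2)), dims=(dim, "variables"), coords={dim: np.arange(3)})]

# same construction as in multiply_extremes: append duplicated extremes along the time dimension
X_extreme = list(map(lambda x1, x2: xr.concat([x1, x2], dim=dim), X, extremes_X))
assert len(X_extreme[0].coords[dim]) == 8  # 5 original + 3 duplicated samples
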
diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py
index b838626221b18c3f3b55ba15e43ae152d83ce5c6..cedb06409a99bf66f4a3a1e695de6059bcd9e143 100644
--- a/mlair/data_handler/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -86,7 +86,7 @@ class KerasIterator(keras.utils.Sequence):
         self.upsampling = upsampling
         self.indexes: list = []
         self._cleanup_path(batch_path)
-        self._prepare_batches_parallel(use_multiprocessing, max_number_multiprocessing)
+        self._prepare_batches(use_multiprocessing, max_number_multiprocessing)
 
     def __len__(self) -> int:
         return len(self.indexes)
@@ -160,7 +160,7 @@ class KerasIterator(keras.utils.Sequence):
             index += 1
         self.indexes = np.arange(0, index).tolist()
 
-    def _prepare_batches_parallel(self, use_multiprocessing=False, max_process=1) -> None:
+    def _prepare_batches(self, use_multiprocessing=False, max_process=1) -> None:
         """
         Prepare all batches as locally stored files.
 
@@ -183,7 +183,7 @@ class KerasIterator(keras.utils.Sequence):
             pool = None
             output = None
         for data in self._collection:
-            length = len(data)
+            length = data.__len__(self.upsampling)
             batches = _get_number_of_mini_batches(length, self.batch_size)
             if pool is None:
                 res = f_proc(data, self.upsampling, mod_rank, self.batch_size, self._path, index)
@@ -209,7 +209,6 @@ class KerasIterator(keras.utils.Sequence):
                 _save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
                 index += 1
         self.indexes = np.arange(0, index).tolist()
-        logging.warning(f"hightst index is {index}")
         if pool is not None:
             pool.join()
 
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index d23d68a823c03a82f88591df5c66e762889f8c93..0d7bb98f109b612cf3cffc3dc31541bb1733c541 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -102,6 +102,7 @@ class Training(RunEnvironment):
         """
         self.model.make_predict_function()
 
+    @TimeTrackingWrapper
     def _set_gen(self, mode: str) -> None:
         """
         Set and distribute the generators for given mode regarding batch size.
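
The only functional change here is decorating _set_gen with TimeTrackingWrapper so that the time spent preparing the generators shows up in the logs. For readers who have not met the helper, a timing decorator of this kind can be sketched as follows (a generic stand-in, not MLAir's actual implementation):

import functools
import logging
import time

def time_tracking_wrapper(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        try:
            return func(*args, **kwargs)
        finally:
            logging.info(f"{func.__name__} finished after {time.time() - start:.2f} s")
    return wrapper
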
diff --git a/test/test_data_handler/test_iterator.py b/test/test_data_handler/test_iterator.py
index bb8ecb5d216519b3662a5baa4d463780b4c29d8c..fe740c094b41df7da45bd0f76d678830f95e1902 100644
--- a/test/test_data_handler/test_iterator.py
+++ b/test/test_data_handler/test_iterator.py
@@ -1,4 +1,5 @@
 from mlair.data_handler.iterator import DataCollection, StandardIterator, KerasIterator
+from mlair.data_handler.iterator import _get_number_of_mini_batches, _get_batch, _permute_data, _save_to_pickle, f_proc
 from mlair.helpers.testing import PyTestAllEqual
 from mlair.model_modules.model_class import MyBranchedModel
 from mlair.model_modules.fully_connected_networks import FCN_64_32_16
@@ -89,6 +90,14 @@ class DummyData:
     def __init__(self, number_of_samples=np.random.randint(100, 150)):
         np.random.seed(45)
         self.number_of_samples = number_of_samples
+        self._len = self.number_of_samples
+        self._len_upsampling = self.number_of_samples
+
+    def __len__(self, upsampling=False):
+        if upsampling is False:
+            return self._len
+        else:
+            return self._len_upsampling
 
     def get_X(self, upsampling=False, as_numpy=True):
         np.random.seed(45)
@@ -152,13 +161,6 @@ class TestKerasIterator:
         iterator._cleanup_path(path, create_new=False)
         assert os.path.exists(path) is False
 
-    def test_get_number_of_mini_batches(self):
-        iterator = object.__new__(KerasIterator)
-        iterator.batch_size = 36
-        assert iterator._get_number_of_mini_batches(30) == 0
-        assert iterator._get_number_of_mini_batches(40) == 1
-        assert iterator._get_number_of_mini_batches(72) == 2
-
     def test_len(self):
         iterator = object.__new__(KerasIterator)
         iterator.indexes = [0, 1, 2, 3, 4, 5]
@@ -175,25 +177,6 @@ class TestKerasIterator:
         for i in range(3):
             assert PyTestAllEqual([new_arr[i], test_arr[i]])
 
-    def test_get_batch(self):
-        arr = DummyData(20).get_X()
-        iterator = object.__new__(KerasIterator)
-        iterator.batch_size = 19
-        batch1 = iterator._get_batch(arr, 0)
-        assert batch1[0].shape[0] == 19
-        batch2 = iterator._get_batch(arr, 1)
-        assert batch2[0].shape[0] == 1
-
-    def test_save_to_pickle(self, path):
-        os.makedirs(path)
-        d = DummyData(20)
-        X, Y = d.get_X(), d.get_Y()
-        iterator = object.__new__(KerasIterator)
-        iterator._path = os.path.join(path, "%i.pickle")
-        assert os.path.exists(iterator._path % 2) is False
-        iterator._save_to_pickle(X=X, Y=Y, index=2)
-        assert os.path.exists(iterator._path % 2) is True
-
     def test_prepare_batches(self, collection, path):
         iterator = object.__new__(KerasIterator)
         iterator._collection = collection
@@ -292,14 +275,111 @@
         with pytest.raises(TypeError):
             iterator._get_model_rank()
 
-    def test_permute(self):
-        iterator = object.__new__(KerasIterator)
+
+class TestGetNumberOfMiniBatches:
+
+    def test_get_number_of_mini_batches(self):
+        batch_size = 36
+        assert _get_number_of_mini_batches(30, batch_size) == 0
+        assert _get_number_of_mini_batches(40, batch_size) == 1
+        assert _get_number_of_mini_batches(72, batch_size) == 2
+
+
+class TestGetBatch:
+
+    def test_get_batch(self):
+        arr = DummyData(20).get_X()
+        batch_size = 19
+        batch1 = _get_batch(arr, 0, batch_size)
+        assert batch1[0].shape[0] == 19
+        batch2 = _get_batch(arr, 1, batch_size)
+        assert batch2[0].shape[0] == 1
+
+
+class TestSaveToPickle:
+
+    @pytest.fixture
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True)
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
+
+    def test_save_to_pickle(self, path):
+        os.makedirs(path)
+        d = DummyData(20)
+        X, Y = d.get_X(), d.get_Y()
+        _path = os.path.join(path, "%i.pickle")
+        assert os.path.exists(_path % 2) is False
+        _save_to_pickle(_path, X=X, Y=Y, index=2)
+        assert os.path.exists(_path % 2) is True
+
+
+class TestPermuteData:
+
+    def test_permute_data(self):
         X = [np.array([[1, 2, 3, 4],
                        [1.1, 2.1, 3.1, 4.1],
                        [1.2, 2.2, 3.2, 4.2]], dtype="f2")]
         Y = [np.array([1, 2, 3])]
-        X_p, Y_p = iterator._permute_data(X, Y)
+        X_p, Y_p = _permute_data(X, Y)
         assert X_p[0].shape == X[0].shape
         assert Y_p[0].shape == Y[0].shape
         assert np.testing.assert_almost_equal(X_p[0].sum(), X[0].sum(), 2) is None
         assert np.testing.assert_almost_equal(Y_p[0].sum(), Y[0].sum(), 2) is None
+
+
+class TestFProc:
+
+    @pytest.fixture
+    def collection(self):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(50 + i))
+        data_coll = DataCollection(collection=coll)
+        return data_coll
+
+    @pytest.fixture
+    def collection_small(self):
+        coll = []
+        for i in range(3):
+            coll.append(DummyData(5 + i))
+        data_coll = DataCollection(collection=coll)
+        return data_coll
+
+    @pytest.fixture
+    def path(self):
+        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+        shutil.rmtree(p, ignore_errors=True)
+        os.makedirs(p)
+        yield p
+        shutil.rmtree(p, ignore_errors=True)
+
+    def test_f_proc(self, collection, path):
+        data = collection[0]
+        upsampling = False
+        mod_rank = 2
+        batch_size = 32
+        remaining = f_proc(data, upsampling, mod_rank, batch_size, os.path.join(path, "%i.pickle"), 0)
+        assert isinstance(remaining, tuple)
+        assert len(remaining) == 2
+        assert isinstance(remaining[0], list)
+        assert len(remaining[0]) == 3
+        assert remaining[0][0].shape == (18, 14, 5)
+
+    def test_f_proc_no_remaining(self, collection, path):
+        data = collection[0]
+        upsampling = False
+        mod_rank = 2
+        batch_size = 50
+        remaining = f_proc(data, upsampling, mod_rank, batch_size, os.path.join(path, "%i.pickle"), 0)
+        assert remaining is None
+
+    def test_f_proc_X_Y(self, collection, path):
+        data = collection[0]
+        X, Y = data.get_data()
+        upsampling = False
+        mod_rank = 2
+        batch_size = 40
+        remaining = f_proc((X, Y), upsampling, mod_rank, batch_size, os.path.join(path, "%i.pickle"), 0)
+        assert remaining[0][0].shape == (10, 14, 5)
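
Since the batching helpers now live at module level, they can be composed outside of KerasIterator as well; a short usage sketch based only on the call signatures exercised above (array shapes and the temporary path are made up):

import os
import tempfile

import numpy as np

from mlair.data_handler.iterator import _get_batch, _get_number_of_mini_batches, _save_to_pickle

X, Y = [np.random.rand(50, 14, 5)], [np.random.rand(50, 3)]
batch_size = 32
full_batches = _get_number_of_mini_batches(len(X[0]), batch_size)  # -> 1

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "%i.pickle")
    for index in range(full_batches):
        _save_to_pickle(path, X=_get_batch(X, index, batch_size),
                        Y=_get_batch(Y, index, batch_size), index=index)
    assert os.path.exists(path % 0)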