diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py
index 83090d6f5f1a79b48e37cf8c38fd4bbd804031a4..1a3268206b33c4cae583b12102f500f854b82467 100644
--- a/mlair/data_handler/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -87,7 +87,8 @@ class KerasIterator(keras.utils.Sequence):
         self.upsampling = upsampling
         self.indexes: list = []
         self._cleanup_path(batch_path)
-        self._prepare_batches(use_multiprocessing, max_number_multiprocessing)
+        self._prepare_batches_parallel(use_multiprocessing, max_number_multiprocessing)
+        self._prepare_batches(False, max_number_multiprocessing)
 
     def __len__(self) -> int:
         return len(self.indexes)
@@ -121,6 +122,11 @@ class KerasIterator(keras.utils.Sequence):
         """Concatenate two lists of data along axis=0."""
         return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new))
 
+    @staticmethod
+    def _concatenate_multi(*args: List[np.ndarray]) -> List[np.ndarray]:
+        """Concatenate two lists of data along axis=0."""
+        return list(map(lambda *_args: np.concatenate(_args, axis=0), *args))
+
     def _get_batch(self, data_list: List[np.ndarray], b: int) -> List[np.ndarray]:
         """Get batch according to batch size from data list."""
         return list(map(lambda data: data[b * self.batch_size:(b + 1) * self.batch_size, ...], data_list))
@@ -132,6 +138,57 @@
         Y = list(map(lambda x: x[p], Y))
         return X, Y
 
+    @TimeTrackingWrapper
+    def _prepare_batches_parallel(self, use_multiprocessing=False, max_process=1) -> None:
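+        """
+        Prepare all mini batches, optionally distributing the work over multiple processes.
+
+        Iterate over the data collection and hand each data element to ``f_proc``, either directly or via a
+        multiprocessing pool. Leftovers smaller than the batch size are concatenated and batched at the end.
+        """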
+        index = 0
+        remaining = []
+        mod_rank = self._get_model_rank()
+        n_process = min([psutil.cpu_count(logical=False), len(self._collection), max_process])  # use only physical cpus
+        if n_process > 1 and use_multiprocessing is True:  # parallel solution
+            pool = multiprocessing.Pool(n_process)
+            output = []
+        else:
+            pool = None
+            output = None
+        for data in self._collection:
+            X, _Y = data.get_data(upsampling=self.upsampling)
+            length = X[0].shape[0]
+            batches = _get_number_of_mini_batches(length, self.batch_size)
+            if pool is None:
+                res = f_proc(X, _Y, self.upsampling, mod_rank, self.batch_size, self._path, index)
+                if res is not None:
+                    remaining.append(res)
+            else:
+                output.append(pool.apply_async(f_proc, args=(X, _Y, self.upsampling, mod_rank, self.batch_size, self._path, index)))
+            index += batches
+        if output is not None:
+            for p in output:
+                res = p.get()
+                if res is not None:
+                    remaining.append(res)
+            pool.close()
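+        # concatenate all leftovers that are smaller than the batch size and build final batches from them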
+        if len(remaining) > 0:
+            X = self._concatenate_multi(*[e[0] for e in remaining])
+            Y = self._concatenate_multi(*[e[1] for e in remaining])
+            length = X[0].shape[0]
+            batches = _get_number_of_mini_batches(length, self.batch_size)
+            remaining = f_proc(X, Y, self.upsampling, mod_rank, self.batch_size, self._path, index)
+            index += batches
+            if remaining is not None:
+                save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
+                index += 1
+        self.indexes = np.arange(0, index).tolist()
+        logging.warning(f"highest index is {index}")
+        if pool is not None:
+            pool.join()
+
     @TimeTrackingWrapper
     def _prepare_batches(self, use_multiprocessing=False, max_process=1) -> None:
         """
@@ -180,6 +231,7 @@ class KerasIterator(keras.utils.Sequence):
             save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
             index += 1
         self.indexes = np.arange(0, index).tolist()
+        logging.warning(f"highest index is {index}")
         if pool is not None:
             pool.close()
             pool.join()
@@ -225,3 +277,38 @@ def save_to_pickle(path, X: List[np.ndarray], Y: List[np.ndarray], index: int) -
 def get_batch(data_list: List[np.ndarray], b: int, batch_size: int) -> List[np.ndarray]:
     """Get batch according to batch size from data list."""
     return list(map(lambda data: data[b * batch_size:(b + 1) * batch_size, ...], data_list))
+
+
+def _permute_data(X, Y):
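+    """Apply a shared random permutation to X and Y along the first (sample) axis."""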
+    p = np.random.permutation(len(X[0]))  # equiv to .shape[0]
+    X = list(map(lambda x: x[p], X))
+    Y = list(map(lambda x: x[p], Y))
+    return X, Y
+
+
+def _get_number_of_mini_batches(number_of_samples: int, batch_size: int) -> int:
+    """Return number of mini batches as the floored ration of number of samples to batch size."""
+    return math.floor(number_of_samples / batch_size)
+
+
+def f_proc(X, _Y, upsampling, mod_rank, batch_size, _path, index):
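+    """
+    Split a single data element into mini batches and pass each batch to ``f_proc_keras_gen``.
+
+    The first target ``_Y[0]`` is replicated ``mod_rank`` times. Return the leftover samples that do not fill a
+    complete batch, or None if the data splits evenly into batches.
+    """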
+    Y = [_Y[0] for _ in range(mod_rank)]
+    if upsampling:
+        X, Y = _permute_data(X, Y)
+    length = X[0].shape[0]
+    batches = _get_number_of_mini_batches(length, batch_size)
+    for b in range(batches):
+        f_proc_keras_gen(X, Y, b, batch_size, index, _path)
+        index += 1
+    if (batches * batch_size) < length:  # keep remaining to concatenate with next data element
+        remaining = (get_batch(X, batches, batch_size), get_batch(Y, batches, batch_size))
+    else:
+        remaining = None
+    return remaining