diff --git a/mlair/data_handler/iterator.py b/mlair/data_handler/iterator.py
index 3fc25a90f861c65d38aa6b7019095210035d4c2d..83090d6f5f1a79b48e37cf8c38fd4bbd804031a4 100644
--- a/mlair/data_handler/iterator.py
+++ b/mlair/data_handler/iterator.py
@@ -8,10 +8,12 @@ import numpy as np
 import math
 import os
 import shutil
-import pickle
+import psutil
+import multiprocessing
 import logging
 import dill
 from typing import Tuple, List
+from mlair.helpers import TimeTrackingWrapper
 
 
 class StandardIterator(Iterator):
@@ -75,7 +77,7 @@ class DataCollection(Iterable):
 class KerasIterator(keras.utils.Sequence):
 
     def __init__(self, collection: DataCollection, batch_size: int, batch_path: str, shuffle_batches: bool = False,
-                 model=None, upsampling=False, name=None):
+                 model=None, upsampling=False, name=None, use_multiprocessing=False, max_number_multiprocessing=1):
         self._collection = collection
         batch_path = os.path.join(batch_path, str(name if name is not None else id(self)))
         self._path = os.path.join(batch_path, "%i.pickle")
@@ -85,7 +87,7 @@ class KerasIterator(keras.utils.Sequence):
         self.upsampling = upsampling
         self.indexes: list = []
         self._cleanup_path(batch_path)
-        self._prepare_batches()
+        self._prepare_batches(use_multiprocessing, max_number_multiprocessing)
 
     def __len__(self) -> int:
         return len(self.indexes)
@@ -130,7 +132,8 @@ class KerasIterator(keras.utils.Sequence):
         Y = list(map(lambda x: x[p], Y))
         return X, Y
 
-    def _prepare_batches(self) -> None:
+    @TimeTrackingWrapper
+    def _prepare_batches(self, use_multiprocessing=False, max_process=1) -> None:
         """
         Prepare all batches as locally stored files.
 
@@ -142,6 +145,12 @@ class KerasIterator(keras.utils.Sequence):
         index = 0
         remaining = None
         mod_rank = self._get_model_rank()
+        # max_process = 12
+        n_process = min([psutil.cpu_count(logical=False), len(self._collection), max_process])  # use only physical cpus
+        if n_process > 1 and use_multiprocessing is True:  # parallel solution
+            pool = multiprocessing.Pool(n_process)
+        else:
+            pool = None
         for data in self._collection:
             logging.debug(f"prepare batches for {str(data)}")
             X, _Y = data.get_data(upsampling=self.upsampling)
@@ -152,18 +161,28 @@ class KerasIterator(keras.utils.Sequence):
                 X, Y = self._concatenate(X, remaining[0]), self._concatenate(Y, remaining[1])
             length = X[0].shape[0]
             batches = self._get_number_of_mini_batches(length)
+            output = []
             for b in range(batches):
-                batch_X, batch_Y = self._get_batch(X, b), self._get_batch(Y, b)
-                self._save_to_pickle(X=batch_X, Y=batch_Y, index=index)
+                if pool is None:
+                    output.append(f_proc_keras_gen(X, Y, b, self.batch_size, index, self._path))
+                else:
+                    output.append(pool.apply_async(f_proc_keras_gen, args=(X, Y, b, self.batch_size, index, self._path)))
                 index += 1
+            if pool is not None:
+                [p.get() for p in output]
             if (batches * self.batch_size) < length:  # keep remaining to concatenate with next data element
-                remaining = (self._get_batch(X, batches), self._get_batch(Y, batches))
+                remaining = (get_batch(X, batches, self.batch_size), get_batch(Y, batches, self.batch_size))
+                # remaining = (self._get_batch(X, batches), self._get_batch(Y, batches))
             else:
                 remaining = None
         if remaining is not None:  # add remaining as smaller batch
-            self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index)
+            # self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index)
+            save_to_pickle(self._path, X=remaining[0], Y=remaining[1], index=index)
             index += 1
         self.indexes = np.arange(0, index).tolist()
+        if pool is not None:
+            pool.close()
+            pool.join()
 
     def _save_to_pickle(self, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None:
         """Save data as pickle file with variables X and Y and given index as <index>.pickle ."""
@@ -188,3 +207,21 @@ class KerasIterator(keras.utils.Sequence):
         """Randomly shuffle indexes if enabled."""
         if self.shuffle is True:
             np.random.shuffle(self.indexes)
+
+
+def f_proc_keras_gen(X, Y, batch_number, batch_size, index, path):
+    batch_X, batch_Y = get_batch(X, batch_number, batch_size), get_batch(Y, batch_number, batch_size)
+    save_to_pickle(path, X=batch_X, Y=batch_Y, index=index)
+
+
+def save_to_pickle(path, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None:
+    """Save data as pickle file with variables X and Y and given index as <index>.pickle ."""
+    data = {"X": X, "Y": Y}
+    file = path % index
+    with open(file, "wb") as f:
+        dill.dump(data, f)
+
+
+def get_batch(data_list: List[np.ndarray], b: int, batch_size: int) -> List[np.ndarray]:
+    """Get batch according to batch size from data list."""
+    return list(map(lambda data: data[b * batch_size:(b + 1) * batch_size, ...], data_list))
diff --git a/mlair/run_modules/training.py b/mlair/run_modules/training.py
index 5ce906122ef184d6dcad5527e923e44f04028fe5..7264d58c3a05ba242ef52a250005d37df9c0745f 100644
--- a/mlair/run_modules/training.py
+++ b/mlair/run_modules/training.py
@@ -109,7 +109,8 @@ class Training(RunEnvironment):
         :param mode: name of set, should be from ["train", "val", "test"]
         """
         collection = self.data_store.get("data_collection", mode)
-        kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path"], scope=mode)
+        kwargs = self.data_store.create_args_dict(["upsampling", "shuffle_batches", "batch_path", "use_multiprocessing", "max_number_multiprocessing"], scope=mode)
+
         setattr(self, f"{mode}_set", KerasIterator(collection, self.batch_size, model=self.model, name=mode, **kwargs))
 
     def set_generators(self) -> None: