diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index c6f38a6f0e70518956bcbbd51a6fdfc1a1e7849f..8a872997877536f948483b66c90db30c1c849f3d 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -12,11 +12,12 @@ import numpy as np
 class Distributor(keras.utils.Sequence):
 
     def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
-                 fit_call: bool = True):
+                 fit_call: bool = True, permute_data: bool = False):
         self.generator = generator
         self.model = model
         self.batch_size = batch_size
         self.fit_call = fit_call
+        self.do_data_permutation = permute_data
 
     def _get_model_rank(self):
         mod_out = self.model.output_shape
@@ -33,6 +34,16 @@ class Distributor(keras.utils.Sequence):
     def _get_number_of_mini_batches(self, values):
         return math.ceil(values[0].shape[0] / self.batch_size)
 
+    def _permute_data(self, x, y):
+        """
+        Permute inputs x and labels y
+        """
+        if self.do_data_permutation:
+            p = np.random.permutation(len(x))  # equiv to .shape[0]
+            x = x[p]
+            y = y[p]
+        return x, y
+
     def distribute_on_batches(self, fit_call=True):
         while True:
             for k, v in enumerate(self.generator):
@@ -42,6 +53,8 @@ class Distributor(keras.utils.Sequence):
                 num_mini_batches = self._get_number_of_mini_batches(v)
                 x_total = np.copy(v[0])
                 y_total = np.copy(v[1])
+                # permute order for mini-batches
+                x_total, y_total = self._permute_data(x_total, y_total)
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
                     x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                     y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index 4c6dbb1c38f2e4a49e53883fbe3cb33cb565118a..109a233ebe4d354bc03359cf5acec81d0f8ebac0 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -38,6 +38,7 @@ class TestDistributor:
     def test_init_defaults(self, distributor):
         assert distributor.batch_size == 256
         assert distributor.fit_call is True
+        assert distributor.do_data_permutation is False
 
     def test_get_model_rank(self, distributor, model_with_minor_branch):
         assert distributor._get_model_rank() == 1
@@ -73,3 +74,28 @@ class TestDistributor:
         d = Distributor(gen, model)
         expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
         assert len(d) == expected
+
+    def test_permute_data_no_permutation(self, distributor):
+        x = np.array(range(20)).reshape(2, 10).T
+        y = np.array(range(10)).reshape(10, 1)
+        x_perm, y_perm = distributor._permute_data(x, y)
+        assert np.testing.assert_equal(x, x_perm) is None
+        assert np.testing.assert_equal(y, y_perm) is None
+
+    def test_permute_data(self, distributor):
+        x = np.array(range(20)).reshape(2, 10).T
+        y = np.array(range(10)).reshape(10, 1)
+        distributor.do_data_permutation = True
+        x_perm, y_perm = distributor._permute_data(x, y)
+        assert x_perm[0, 0] == y_perm[0]
+        assert x_perm[0, 1] == y_perm[0] + 10
+        assert x_perm[5, 0] == y_perm[5]
+        assert x_perm[5, 1] == y_perm[5] + 10
+        assert x_perm[-1, 0] == y_perm[-1]
+        assert x_perm[-1, 1] == y_perm[-1] + 10
+        # resort x_perm and compare if equal to x
+        x_perm.sort(axis=0)
+        y_perm.sort(axis=0)
+        assert np.testing.assert_equal(x, x_perm) is None
+        assert np.testing.assert_equal(y, y_perm) is None
+