From dc525c313f91d83b7f2c96f543938f93ceb5e13c Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Tue, 18 Feb 2020 10:36:26 +0100
Subject: [PATCH] related to #57 - add static method to permutate inputs and
 labels - call static method within distribute on batches

---
 src/data_handling/data_distributor.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index c6f38a6f..ece2589a 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -33,6 +33,16 @@ class Distributor(keras.utils.Sequence):
     def _get_number_of_mini_batches(self, values):
         return math.ceil(values[0].shape[0] / self.batch_size)
 
+    @staticmethod
+    def _permutate_data(x, y):
+        """
+        Permutate inputs x and labels y
+        """
+        p = np.random.permutation(len(x))  # equiv to .shape[0]
+        x = x[p]
+        y = y[p]
+        return x, y
+
     def distribute_on_batches(self, fit_call=True):
         while True:
             for k, v in enumerate(self.generator):
@@ -42,6 +52,8 @@ class Distributor(keras.utils.Sequence):
                 num_mini_batches = self._get_number_of_mini_batches(v)
                 x_total = np.copy(v[0])
                 y_total = np.copy(v[1])
+                # permutate order for mini-batches
+                x_total, y_total = self._permutate_data(x_total, y_total)
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
                     x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                     y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
-- 
GitLab