From dc525c313f91d83b7f2c96f543938f93ceb5e13c Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Tue, 18 Feb 2020 10:36:26 +0100
Subject: [PATCH] related to #57 - add static method to permutate inputs and labels - call static method within distribute on batches

---
 src/data_handling/data_distributor.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index c6f38a6f..ece2589a 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -33,6 +33,16 @@ class Distributor(keras.utils.Sequence):
     def _get_number_of_mini_batches(self, values):
         return math.ceil(values[0].shape[0] / self.batch_size)
 
+    @staticmethod
+    def _permutate_data(x, y):
+        """
+        Shuffle x and y along axis 0 using one shared random permutation and return them as (x, y).
+        """
+        p = np.random.permutation(len(x))  # len(x) is x.shape[0]; assumes y has the same length (TODO confirm)
+        x = x[p]
+        y = y[p]
+        return x, y
+
     def distribute_on_batches(self, fit_call=True):
         while True:
             for k, v in enumerate(self.generator):
@@ -42,6 +52,8 @@ class Distributor(keras.utils.Sequence):
                 num_mini_batches = self._get_number_of_mini_batches(v)
                 x_total = np.copy(v[0])
                 y_total = np.copy(v[1])
+                # permutate order for mini-batches
+                x_total, y_total = self._permutate_data(x_total, y_total)
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
                     x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                     y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
-- 
GitLab