diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index c6f38a6f0e70518956bcbbd51a6fdfc1a1e7849f..8a872997877536f948483b66c90db30c1c849f3d 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -12,11 +12,12 @@ import numpy as np class Distributor(keras.utils.Sequence): def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256, - fit_call: bool = True): + fit_call: bool = True, permute_data: bool = False): self.generator = generator self.model = model self.batch_size = batch_size self.fit_call = fit_call + self.do_data_permutation = permute_data def _get_model_rank(self): mod_out = self.model.output_shape @@ -33,6 +34,16 @@ class Distributor(keras.utils.Sequence): def _get_number_of_mini_batches(self, values): return math.ceil(values[0].shape[0] / self.batch_size) + def _permute_data(self, x, y): + """ + Permute inputs x and labels y + """ + if self.do_data_permutation: + p = np.random.permutation(len(x)) # equiv to .shape[0] + x = x[p] + y = y[p] + return x, y + def distribute_on_batches(self, fit_call=True): while True: for k, v in enumerate(self.generator): @@ -42,6 +53,8 @@ class Distributor(keras.utils.Sequence): num_mini_batches = self._get_number_of_mini_batches(v) x_total = np.copy(v[0]) y_total = np.copy(v[1]) + # permute order for mini-batches + x_total, y_total = self._permute_data(x_total, y_total) for prev, curr in enumerate(range(1, num_mini_batches+1)): x = x_total[prev*self.batch_size:curr*self.batch_size, ...] y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)] diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index 4c6dbb1c38f2e4a49e53883fbe3cb33cb565118a..109a233ebe4d354bc03359cf5acec81d0f8ebac0 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -38,6 +38,7 @@ class TestDistributor: def test_init_defaults(self, distributor): assert distributor.batch_size == 256 assert distributor.fit_call is True + assert distributor.do_data_permutation is False def test_get_model_rank(self, distributor, model_with_minor_branch): assert distributor._get_model_rank() == 1 @@ -73,3 +74,28 @@ class TestDistributor: d = Distributor(gen, model) expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256) assert len(d) == expected + + def test_permute_data_no_permutation(self, distributor): + x = np.array(range(20)).reshape(2, 10).T + y = np.array(range(10)).reshape(10, 1) + x_perm, y_perm = distributor._permute_data(x, y) + assert np.testing.assert_equal(x, x_perm) is None + assert np.testing.assert_equal(y, y_perm) is None + + def test_permute_data(self, distributor): + x = np.array(range(20)).reshape(2, 10).T + y = np.array(range(10)).reshape(10, 1) + distributor.do_data_permutation = True + x_perm, y_perm = distributor._permute_data(x, y) + assert x_perm[0, 0] == y_perm[0] + assert x_perm[0, 1] == y_perm[0] + 10 + assert x_perm[5, 0] == y_perm[5] + assert x_perm[5, 1] == y_perm[5] + 10 + assert x_perm[-1, 0] == y_perm[-1] + assert x_perm[-1, 1] == y_perm[-1] + 10 + # resort x_perm and compare if equal to x + x_perm.sort(axis=0) + y_perm.sort(axis=0) + assert np.testing.assert_equal(x, x_perm) is None + assert np.testing.assert_equal(y, y_perm) is None +