From fc298e803deacebcfc99a736e500838f0053911f Mon Sep 17 00:00:00 2001 From: lukas leufen <l.leufen@fz-juelich.de> Date: Mon, 2 Mar 2020 14:32:56 +0100 Subject: [PATCH] set default for permutation to False, added some tests --- src/data_handling/data_distributor.py | 14 ++++++------- .../test_data_distributor.py | 20 ++++++++++++++----- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index c750c587..8a872997 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -12,12 +12,12 @@ import numpy as np class Distributor(keras.utils.Sequence): def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256, - fit_call: bool = True, permutate_data: bool = True): + fit_call: bool = True, permute_data: bool = False): self.generator = generator self.model = model self.batch_size = batch_size self.fit_call = fit_call - self.permutate_data = permutate_data + self.do_data_permutation = permute_data def _get_model_rank(self): mod_out = self.model.output_shape @@ -34,11 +34,11 @@ class Distributor(keras.utils.Sequence): def _get_number_of_mini_batches(self, values): return math.ceil(values[0].shape[0] / self.batch_size) - def _permutate_data(self, x, y): + def _permute_data(self, x, y): """ - Permutate inputs x and labels y + Permute inputs x and labels y """ - if self.permutate_data: + if self.do_data_permutation: p = np.random.permutation(len(x)) # equiv to .shape[0] x = x[p] y = y[p] @@ -53,8 +53,8 @@ class Distributor(keras.utils.Sequence): num_mini_batches = self._get_number_of_mini_batches(v) x_total = np.copy(v[0]) y_total = np.copy(v[1]) - # permutate order for mini-batches - x_total, y_total = self._permutate_data(x_total, y_total) + # permute order for mini-batches + x_total, y_total = self._permute_data(x_total, y_total) for prev, curr in enumerate(range(1, num_mini_batches+1)): x = x_total[prev*self.batch_size:curr*self.batch_size, ...] y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)] diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index 776ac000..109a233e 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -38,6 +38,7 @@ class TestDistributor: def test_init_defaults(self, distributor): assert distributor.batch_size == 256 assert distributor.fit_call is True + assert distributor.do_data_permutation is False def test_get_model_rank(self, distributor, model_with_minor_branch): assert distributor._get_model_rank() == 1 @@ -74,18 +75,27 @@ class TestDistributor: expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256) assert len(d) == expected - def test_permutate_data(self, distributor): + def test_permute_data_no_permutation(self, distributor): x = np.array(range(20)).reshape(2, 10).T y = np.array(range(10)).reshape(10, 1) - x_perm, y_perm = distributor._permutate_data(x, y) + x_perm, y_perm = distributor._permute_data(x, y) + assert np.testing.assert_equal(x, x_perm) is None + assert np.testing.assert_equal(y, y_perm) is None + + def test_permute_data(self, distributor): + x = np.array(range(20)).reshape(2, 10).T + y = np.array(range(10)).reshape(10, 1) + distributor.do_data_permutation = True + x_perm, y_perm = distributor._permute_data(x, y) assert x_perm[0, 0] == y_perm[0] assert x_perm[0, 1] == y_perm[0] + 10 assert x_perm[5, 0] == y_perm[5] assert x_perm[5, 1] == y_perm[5] + 10 assert x_perm[-1, 0] == y_perm[-1] assert x_perm[-1, 1] == y_perm[-1] + 10 - #resort x_perm and compare if equal to x + # resort x_perm and compare if equal to x x_perm.sort(axis=0) y_perm.sort(axis=0) - assert (x == x_perm).all() - assert (y == y_perm).all() + assert np.testing.assert_equal(x, x_perm) is None + assert np.testing.assert_equal(y, y_perm) is None + -- GitLab