From fc298e803deacebcfc99a736e500838f0053911f Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 2 Mar 2020 14:32:56 +0100
Subject: [PATCH] set default for permutation to False, added some tests

---
 src/data_handling/data_distributor.py         | 14 ++++++-------
 .../test_data_distributor.py                  | 20 ++++++++++++++-----
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index c750c587..8a872997 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -12,12 +12,12 @@ import numpy as np
 class Distributor(keras.utils.Sequence):
 
     def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
-                 fit_call: bool = True, permutate_data: bool = True):
+                 fit_call: bool = True, permute_data: bool = False):
         self.generator = generator
         self.model = model
         self.batch_size = batch_size
         self.fit_call = fit_call
-        self.permutate_data = permutate_data
+        self.do_data_permutation = permute_data
 
     def _get_model_rank(self):
         mod_out = self.model.output_shape
@@ -34,11 +34,11 @@ class Distributor(keras.utils.Sequence):
     def _get_number_of_mini_batches(self, values):
         return math.ceil(values[0].shape[0] / self.batch_size)
 
-    def _permutate_data(self, x, y):
+    def _permute_data(self, x, y):
         """
-        Permutate inputs x and labels y
+        Permute inputs x and labels y
         """
-        if self.permutate_data:
+        if self.do_data_permutation:
             p = np.random.permutation(len(x))  # equiv to .shape[0]
             x = x[p]
             y = y[p]
@@ -53,8 +53,8 @@ class Distributor(keras.utils.Sequence):
                 num_mini_batches = self._get_number_of_mini_batches(v)
                 x_total = np.copy(v[0])
                 y_total = np.copy(v[1])
-                # permutate order for mini-batches
-                x_total, y_total = self._permutate_data(x_total, y_total)
+                # permute order for mini-batches
+                x_total, y_total = self._permute_data(x_total, y_total)
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
                     x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                     y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index 776ac000..109a233e 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -38,6 +38,7 @@ class TestDistributor:
     def test_init_defaults(self, distributor):
         assert distributor.batch_size == 256
         assert distributor.fit_call is True
+        assert distributor.do_data_permutation is False
 
     def test_get_model_rank(self, distributor, model_with_minor_branch):
         assert distributor._get_model_rank() == 1
@@ -74,18 +75,27 @@ class TestDistributor:
         expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
         assert len(d) == expected
 
-    def test_permutate_data(self, distributor):
+    def test_permute_data_no_permutation(self, distributor):
         x = np.array(range(20)).reshape(2, 10).T
         y = np.array(range(10)).reshape(10, 1)
-        x_perm, y_perm = distributor._permutate_data(x, y)
+        x_perm, y_perm = distributor._permute_data(x, y)
+        assert np.testing.assert_equal(x, x_perm) is None
+        assert np.testing.assert_equal(y, y_perm) is None
+
+    def test_permute_data(self, distributor):
+        x = np.array(range(20)).reshape(2, 10).T
+        y = np.array(range(10)).reshape(10, 1)
+        distributor.do_data_permutation = True
+        x_perm, y_perm = distributor._permute_data(x, y)
         assert x_perm[0, 0] == y_perm[0]
         assert x_perm[0, 1] == y_perm[0] + 10
         assert x_perm[5, 0] == y_perm[5]
         assert x_perm[5, 1] == y_perm[5] + 10
         assert x_perm[-1, 0] == y_perm[-1]
         assert x_perm[-1, 1] == y_perm[-1] + 10
-        #resort x_perm and compare if equal to x
+        # resort x_perm and compare if equal to x
         x_perm.sort(axis=0)
         y_perm.sort(axis=0)
-        assert (x == x_perm).all()
-        assert (y == y_perm).all()
+        assert np.testing.assert_equal(x, x_perm) is None
+        assert np.testing.assert_equal(y, y_perm) is None
+
-- 
GitLab