From 545146bbe16a7610b272879495342d4e47e4bfa4 Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Fri, 28 Feb 2020 11:34:31 +0100 Subject: [PATCH] add test_permutate_data for issue #57 --- test/test_data_handling/test_data_distributor.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index 4c6dbb1c..776ac000 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -73,3 +73,19 @@ class TestDistributor: d = Distributor(gen, model) expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256) assert len(d) == expected + + def test_permutate_data(self, distributor): + x = np.array(range(20)).reshape(2, 10).T + y = np.array(range(10)).reshape(10, 1) + x_perm, y_perm = distributor._permutate_data(x, y) + assert x_perm[0, 0] == y_perm[0] + assert x_perm[0, 1] == y_perm[0] + 10 + assert x_perm[5, 0] == y_perm[5] + assert x_perm[5, 1] == y_perm[5] + 10 + assert x_perm[-1, 0] == y_perm[-1] + assert x_perm[-1, 1] == y_perm[-1] + 10 + #resort x_perm and compare if equal to x + x_perm.sort(axis=0) + y_perm.sort(axis=0) + assert (x == x_perm).all() + assert (y == y_perm).all() -- GitLab