From 545146bbe16a7610b272879495342d4e47e4bfa4 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Fri, 28 Feb 2020 11:34:31 +0100
Subject: [PATCH] add test_permutate_data for issue #57

---
 test/test_data_handling/test_data_distributor.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py
index 4c6dbb1c..776ac000 100644
--- a/test/test_data_handling/test_data_distributor.py
+++ b/test/test_data_handling/test_data_distributor.py
@@ -73,3 +73,19 @@ class TestDistributor:
         d = Distributor(gen, model)
         expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
         assert len(d) == expected
+
+    def test_permutate_data(self, distributor):
+        x = np.array(range(20)).reshape(2, 10).T
+        y = np.array(range(10)).reshape(10, 1)
+        x_perm, y_perm = distributor._permutate_data(x, y)
+        assert x_perm[0, 0] == y_perm[0]
+        assert x_perm[0, 1] == y_perm[0] + 10
+        assert x_perm[5, 0] == y_perm[5]
+        assert x_perm[5, 1] == y_perm[5] + 10
+        assert x_perm[-1, 0] == y_perm[-1]
+        assert x_perm[-1, 1] == y_perm[-1] + 10
+        #resort x_perm and compare if equal to x
+        x_perm.sort(axis=0)
+        y_perm.sort(axis=0)
+        assert (x == x_perm).all()
+        assert (y == y_perm).all()
-- 
GitLab