diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index 4c6dbb1c38f2e4a49e53883fbe3cb33cb565118a..776ac000824f8f68712aa6ee2fd8096eb70f6829 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -73,3 +73,19 @@ class TestDistributor: d = Distributor(gen, model) expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256) assert len(d) == expected + + def test_permutate_data(self, distributor): + x = np.array(range(20)).reshape(2, 10).T + y = np.array(range(10)).reshape(10, 1) + x_perm, y_perm = distributor._permutate_data(x, y) + assert x_perm[0, 0] == y_perm[0] + assert x_perm[0, 1] == y_perm[0] + 10 + assert x_perm[5, 0] == y_perm[5] + assert x_perm[5, 1] == y_perm[5] + 10 + assert x_perm[-1, 0] == y_perm[-1] + assert x_perm[-1, 1] == y_perm[-1] + 10 + #resort x_perm and compare if equal to x + x_perm.sort(axis=0) + y_perm.sort(axis=0) + assert (x == x_perm).all() + assert (y == y_perm).all()