Skip to content
Snippets Groups Projects
Commit 545146bb authored by Felix Kleinert's avatar Felix Kleinert
Browse files

add test_permutate_data for issue #57

parent 3eb24dcd
Branches
Tags
2 merge requests!50release for v0.7.0,!44Felix issue057 permutate data for minibatches
Pipeline #30717 passed
......@@ -73,3 +73,19 @@ class TestDistributor:
d = Distributor(gen, model)
expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
assert len(d) == expected
def test_permutate_data(self, distributor):
x = np.array(range(20)).reshape(2, 10).T
y = np.array(range(10)).reshape(10, 1)
x_perm, y_perm = distributor._permutate_data(x, y)
assert x_perm[0, 0] == y_perm[0]
assert x_perm[0, 1] == y_perm[0] + 10
assert x_perm[5, 0] == y_perm[5]
assert x_perm[5, 1] == y_perm[5] + 10
assert x_perm[-1, 0] == y_perm[-1]
assert x_perm[-1, 1] == y_perm[-1] + 10
#resort x_perm and compare if equal to x
x_perm.sort(axis=0)
y_perm.sort(axis=0)
assert (x == x_perm).all()
assert (y == y_perm).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment