Skip to content
Snippets Groups Projects
Commit fc298e80 authored by lukas leufen's avatar lukas leufen
Browse files

set default for permutation to False, added some tests

parent 545146bb
No related branches found
No related tags found
2 merge requests!50release for v0.7.0,!44Felix issue057 permutate data for minibatches
Pipeline #30911 passed
......@@ -12,12 +12,12 @@ import numpy as np
class Distributor(keras.utils.Sequence):
def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
fit_call: bool = True, permutate_data: bool = True):
fit_call: bool = True, permute_data: bool = False):
self.generator = generator
self.model = model
self.batch_size = batch_size
self.fit_call = fit_call
self.permutate_data = permutate_data
self.do_data_permutation = permute_data
def _get_model_rank(self):
mod_out = self.model.output_shape
......@@ -34,11 +34,11 @@ class Distributor(keras.utils.Sequence):
def _get_number_of_mini_batches(self, values):
return math.ceil(values[0].shape[0] / self.batch_size)
def _permutate_data(self, x, y):
def _permute_data(self, x, y):
"""
Permutate inputs x and labels y
Permute inputs x and labels y
"""
if self.permutate_data:
if self.do_data_permutation:
p = np.random.permutation(len(x)) # equiv to .shape[0]
x = x[p]
y = y[p]
......@@ -53,8 +53,8 @@ class Distributor(keras.utils.Sequence):
num_mini_batches = self._get_number_of_mini_batches(v)
x_total = np.copy(v[0])
y_total = np.copy(v[1])
# permutate order for mini-batches
x_total, y_total = self._permutate_data(x_total, y_total)
# permute order for mini-batches
x_total, y_total = self._permute_data(x_total, y_total)
for prev, curr in enumerate(range(1, num_mini_batches+1)):
x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
......
......@@ -38,6 +38,7 @@ class TestDistributor:
def test_init_defaults(self, distributor):
assert distributor.batch_size == 256
assert distributor.fit_call is True
assert distributor.do_data_permutation is False
def test_get_model_rank(self, distributor, model_with_minor_branch):
assert distributor._get_model_rank() == 1
......@@ -74,10 +75,18 @@ class TestDistributor:
expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
assert len(d) == expected
def test_permutate_data(self, distributor):
def test_permute_data_no_permutation(self, distributor):
x = np.array(range(20)).reshape(2, 10).T
y = np.array(range(10)).reshape(10, 1)
x_perm, y_perm = distributor._permutate_data(x, y)
x_perm, y_perm = distributor._permute_data(x, y)
assert np.testing.assert_equal(x, x_perm) is None
assert np.testing.assert_equal(y, y_perm) is None
def test_permute_data(self, distributor):
x = np.array(range(20)).reshape(2, 10).T
y = np.array(range(10)).reshape(10, 1)
distributor.do_data_permutation = True
x_perm, y_perm = distributor._permute_data(x, y)
assert x_perm[0, 0] == y_perm[0]
assert x_perm[0, 1] == y_perm[0] + 10
assert x_perm[5, 0] == y_perm[5]
......@@ -87,5 +96,6 @@ class TestDistributor:
# resort x_perm and compare if equal to x
x_perm.sort(axis=0)
y_perm.sort(axis=0)
assert (x == x_perm).all()
assert (y == y_perm).all()
assert np.testing.assert_equal(x, x_perm) is None
assert np.testing.assert_equal(y, y_perm) is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment