Skip to content
Snippets Groups Projects
Commit fc298e80 authored by lukas leufen
Browse files

set default for permutation to False, added some tests

parent 545146bb
Branches
Tags
2 merge requests: !50 (release for v0.7.0), !44 (Felix issue057 permutate data for minibatches)
Pipeline #30911 passed
......@@ -12,12 +12,12 @@ import numpy as np
class Distributor(keras.utils.Sequence):
def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
fit_call: bool = True, permutate_data: bool = True):
fit_call: bool = True, permute_data: bool = False):
self.generator = generator
self.model = model
self.batch_size = batch_size
self.fit_call = fit_call
self.permutate_data = permutate_data
self.do_data_permutation = permute_data
def _get_model_rank(self):
mod_out = self.model.output_shape
......@@ -34,11 +34,11 @@ class Distributor(keras.utils.Sequence):
def _get_number_of_mini_batches(self, values):
return math.ceil(values[0].shape[0] / self.batch_size)
def _permutate_data(self, x, y):
def _permute_data(self, x, y):
"""
Permutate inputs x and labels y
Permute inputs x and labels y
"""
if self.permutate_data:
if self.do_data_permutation:
p = np.random.permutation(len(x)) # equiv to .shape[0]
x = x[p]
y = y[p]
......@@ -53,8 +53,8 @@ class Distributor(keras.utils.Sequence):
num_mini_batches = self._get_number_of_mini_batches(v)
x_total = np.copy(v[0])
y_total = np.copy(v[1])
# permutate order for mini-batches
x_total, y_total = self._permutate_data(x_total, y_total)
# permute order for mini-batches
x_total, y_total = self._permute_data(x_total, y_total)
for prev, curr in enumerate(range(1, num_mini_batches+1)):
x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
......
......@@ -38,6 +38,7 @@ class TestDistributor:
# Checks that the `distributor` fixture reports batch_size 256, fit_call True,
# and do_data_permutation False.
def test_init_defaults(self, distributor):
assert distributor.batch_size == 256
assert distributor.fit_call is True
assert distributor.do_data_permutation is False
def test_get_model_rank(self, distributor, model_with_minor_branch):
assert distributor._get_model_rank() == 1
......@@ -74,10 +75,18 @@ class TestDistributor:
expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
assert len(d) == expected
def test_permutate_data(self, distributor):
def test_permute_data_no_permutation(self, distributor):
x = np.array(range(20)).reshape(2, 10).T
y = np.array(range(10)).reshape(10, 1)
x_perm, y_perm = distributor._permutate_data(x, y)
x_perm, y_perm = distributor._permute_data(x, y)
assert np.testing.assert_equal(x, x_perm) is None
assert np.testing.assert_equal(y, y_perm) is None
def test_permute_data(self, distributor):
x = np.array(range(20)).reshape(2, 10).T
y = np.array(range(10)).reshape(10, 1)
distributor.do_data_permutation = True
x_perm, y_perm = distributor._permute_data(x, y)
assert x_perm[0, 0] == y_perm[0]
assert x_perm[0, 1] == y_perm[0] + 10
assert x_perm[5, 0] == y_perm[5]
......@@ -87,5 +96,6 @@ class TestDistributor:
# resort x_perm and compare if equal to x
x_perm.sort(axis=0)
y_perm.sort(axis=0)
assert (x == x_perm).all()
assert (y == y_perm).all()
assert np.testing.assert_equal(x, x_perm) is None
assert np.testing.assert_equal(y, y_perm) is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment