Skip to content
Snippets Groups Projects
Commit 34277ea7 authored by lukas leufen's avatar lukas leufen
Browse files

include data permutation, /close #57

parents 029ffbb0 fc298e80
Branches
Tags
1 merge request!50release for v0.7.0
......@@ -12,11 +12,12 @@ import numpy as np
class Distributor(keras.utils.Sequence):
def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
fit_call: bool = True):
fit_call: bool = True, permute_data: bool = False):
self.generator = generator
self.model = model
self.batch_size = batch_size
self.fit_call = fit_call
self.do_data_permutation = permute_data
def _get_model_rank(self):
mod_out = self.model.output_shape
......@@ -33,6 +34,16 @@ class Distributor(keras.utils.Sequence):
def _get_number_of_mini_batches(self, values):
return math.ceil(values[0].shape[0] / self.batch_size)
def _permute_data(self, x, y):
"""
Permute inputs x and labels y
"""
if self.do_data_permutation:
p = np.random.permutation(len(x)) # equiv to .shape[0]
x = x[p]
y = y[p]
return x, y
def distribute_on_batches(self, fit_call=True):
while True:
for k, v in enumerate(self.generator):
......@@ -42,6 +53,8 @@ class Distributor(keras.utils.Sequence):
num_mini_batches = self._get_number_of_mini_batches(v)
x_total = np.copy(v[0])
y_total = np.copy(v[1])
# permute order for mini-batches
x_total, y_total = self._permute_data(x_total, y_total)
for prev, curr in enumerate(range(1, num_mini_batches+1)):
x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
......
......@@ -38,6 +38,7 @@ class TestDistributor:
def test_init_defaults(self, distributor):
assert distributor.batch_size == 256
assert distributor.fit_call is True
assert distributor.do_data_permutation is False
def test_get_model_rank(self, distributor, model_with_minor_branch):
assert distributor._get_model_rank() == 1
......@@ -73,3 +74,28 @@ class TestDistributor:
d = Distributor(gen, model)
expected = math.ceil(len(gen[0][0]) / 256) + math.ceil(len(gen[1][0]) / 256)
assert len(d) == expected
def test_permute_data_no_permutation(self, distributor):
x = np.array(range(20)).reshape(2, 10).T
y = np.array(range(10)).reshape(10, 1)
x_perm, y_perm = distributor._permute_data(x, y)
assert np.testing.assert_equal(x, x_perm) is None
assert np.testing.assert_equal(y, y_perm) is None
def test_permute_data(self, distributor):
x = np.array(range(20)).reshape(2, 10).T
y = np.array(range(10)).reshape(10, 1)
distributor.do_data_permutation = True
x_perm, y_perm = distributor._permute_data(x, y)
assert x_perm[0, 0] == y_perm[0]
assert x_perm[0, 1] == y_perm[0] + 10
assert x_perm[5, 0] == y_perm[5]
assert x_perm[5, 1] == y_perm[5] + 10
assert x_perm[-1, 0] == y_perm[-1]
assert x_perm[-1, 1] == y_perm[-1] + 10
# resort x_perm and compare if equal to x
x_perm.sort(axis=0)
y_perm.sort(axis=0)
assert np.testing.assert_equal(x, x_perm) is None
assert np.testing.assert_equal(y, y_perm) is None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment