Skip to content
Snippets Groups Projects
Commit 3eb24dcd authored by Felix Kleinert's avatar Felix Kleinert
Browse files

related to #57

- add switch into __init__ method
- _permutate_data() is no longer a static method
parent dc525c31
No related branches found
No related tags found
2 merge requests!50release for v0.7.0,!44Felix issue057 permutate data for minibatches
Pipeline #29815 passed
...@@ -12,11 +12,12 @@ import numpy as np ...@@ -12,11 +12,12 @@ import numpy as np
class Distributor(keras.utils.Sequence): class Distributor(keras.utils.Sequence):
def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256, def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256,
fit_call: bool = True): fit_call: bool = True, permutate_data: bool = True):
self.generator = generator self.generator = generator
self.model = model self.model = model
self.batch_size = batch_size self.batch_size = batch_size
self.fit_call = fit_call self.fit_call = fit_call
self.permutate_data = permutate_data
def _get_model_rank(self): def _get_model_rank(self):
mod_out = self.model.output_shape mod_out = self.model.output_shape
...@@ -33,11 +34,11 @@ class Distributor(keras.utils.Sequence): ...@@ -33,11 +34,11 @@ class Distributor(keras.utils.Sequence):
def _get_number_of_mini_batches(self, values): def _get_number_of_mini_batches(self, values):
return math.ceil(values[0].shape[0] / self.batch_size) return math.ceil(values[0].shape[0] / self.batch_size)
@staticmethod def _permutate_data(self, x, y):
def _permutate_data(x, y):
""" """
Permutate inputs x and labels y Permutate inputs x and labels y
""" """
if self.permutate_data:
p = np.random.permutation(len(x)) # equiv to .shape[0] p = np.random.permutation(len(x)) # equiv to .shape[0]
x = x[p] x = x[p]
y = y[p] y = y[p]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment