diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index ece2589aa78a0185a2edc0fd5f628875da3fd012..c750c5870f38a7378ed36a29f9052363759b98ea 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -12,11 +12,12 @@ import numpy as np class Distributor(keras.utils.Sequence): def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256, - fit_call: bool = True): + fit_call: bool = True, permutate_data: bool = True): self.generator = generator self.model = model self.batch_size = batch_size self.fit_call = fit_call + self.permutate_data = permutate_data def _get_model_rank(self): mod_out = self.model.output_shape @@ -33,14 +34,14 @@ class Distributor(keras.utils.Sequence): def _get_number_of_mini_batches(self, values): return math.ceil(values[0].shape[0] / self.batch_size) - @staticmethod - def _permutate_data(x, y): + def _permutate_data(self, x, y): """ Permutate inputs x and labels y """ - p = np.random.permutation(len(x)) # equiv to .shape[0] - x = x[p] - y = y[p] + if self.permutate_data: + p = np.random.permutation(len(x)) # equiv to .shape[0] + x = x[p] + y = y[p] return x, y def distribute_on_batches(self, fit_call=True):