From 3eb24dcd71d73b3be30ab7eb40fc0e79ecc6a265 Mon Sep 17 00:00:00 2001 From: Felix Kleinert <f.kleinert@fz-juelich.de> Date: Tue, 18 Feb 2020 11:16:08 +0100 Subject: [PATCH] related to #57 - add switch into __init__ method - _permutate_data() is no longer a static method --- src/data_handling/data_distributor.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index ece2589a..c750c587 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -12,11 +12,12 @@ import numpy as np class Distributor(keras.utils.Sequence): def __init__(self, generator: keras.utils.Sequence, model: keras.models, batch_size: int = 256, - fit_call: bool = True): + fit_call: bool = True, permutate_data: bool = True): self.generator = generator self.model = model self.batch_size = batch_size self.fit_call = fit_call + self.permutate_data = permutate_data def _get_model_rank(self): mod_out = self.model.output_shape @@ -33,14 +34,14 @@ class Distributor(keras.utils.Sequence): def _get_number_of_mini_batches(self, values): return math.ceil(values[0].shape[0] / self.batch_size) - @staticmethod - def _permutate_data(x, y): + def _permutate_data(self, x, y): """ Permutate inputs x and labels y """ - p = np.random.permutation(len(x)) # equiv to .shape[0] - x = x[p] - y = y[p] + if self.permutate_data: + p = np.random.permutation(len(x)) # equiv to .shape[0] + x = x[p] + y = y[p] return x, y def distribute_on_batches(self, fit_call=True): -- GitLab