Permute data for mini-batches
Duplicate of #47 (closed)
The data of one station should be permuted before the mini-batches are created. As best practice, each mini-batch should be a representative sample of the full data distribution. Non-permuted data would lead to "seasonal batches", for example batches consisting only of DJF or MAM samples.
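To illustrate the "seasonal batches" problem, here is a minimal sketch on toy data. The one-year daily series, the batch size of 64 and the helper `seasons_per_batch` are assumptions made up for this example; none of them come from the project code.

```python
import numpy as np

# Hypothetical toy data: one season label per daily sample for a single
# non-leap year (0=DJF, 1=MAM, 2=JJA, 3=SON), in chronological order.
month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
months = np.repeat(np.arange(1, 13), month_lengths)
season = (months % 12) // 3

batch_size = 64  # assumed value, for illustration only


def seasons_per_batch(labels, size):
    """Count the distinct seasons in each consecutive mini-batch."""
    return [len(np.unique(labels[i:i + size])) for i in range(0, len(labels), size)]


# Chronological order: each batch covers only one or two seasons.
ordered = seasons_per_batch(season, batch_size)

# Permuted order: each batch draws from the whole year.
rng = np.random.default_rng(0)  # fixed seed only to make the sketch repeatable
shuffled = seasons_per_batch(season[rng.permutation(len(season))], batch_size)

print("seasons per batch, chronological:", ordered)
print("seasons per batch, permuted:     ", shuffled)
```

With chronological slicing, no batch can span more than two seasons here, because each season is longer than one batch.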
- Create a new index ordering using np.random.permutation.
- Permute the inputs (x_total) using the new indices.
- Permute the labels (y_total) using the same indices as for x_total.
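The three steps can be sketched in isolation as follows. The arrays are small hypothetical stand-ins for `x_total` and `y_total`, built so that each label encodes its sample index and the alignment is easy to check.

```python
import numpy as np

# Hypothetical stand-ins for x_total and y_total: 10 samples, 3 features.
# The label of sample i is i, and the first feature of sample i is 3 * i.
x_total = np.arange(10 * 3).reshape(10, 3)
y_total = x_total[:, 0] // 3

# 1. create a new index ordering
p = np.random.permutation(len(x_total))
# 2. permute the inputs (fancy indexing returns a copy)
x_perm = x_total[p]
# 3. permute the labels with the same indices
y_perm = y_total[p]

# every (input, label) pair is still matched after shuffling
assert np.array_equal(x_perm[:, 0] // 3, y_perm)
# no sample was lost or duplicated
assert sorted(y_perm.tolist()) == list(range(10))
```

Using one index array for both arrays is what keeps inputs and labels paired. Shuffling each array separately with its own random call would break that pairing.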
The permutation should happen directly before the mini-batch loop in distribute_on_batches (data_distributor.py).
The method could look like this:
    def distribute_on_batches(self, fit_call=True):
        while True:
            for k, v in enumerate(self.generator):
                # get rank of output
                mod_rank = self._get_model_rank()
                # get number of mini batches
                num_mini_batches = self._get_number_of_mini_batches(v)
                x_total = np.copy(v[0])
                y_total = np.copy(v[1])
                ### BEGIN NEW CODE ###
                p = np.random.permutation(len(x_total))
                x_total = x_total[p]
                y_total = y_total[p]
                ### END NEW CODE ###
                for prev, curr in enumerate(range(1, num_mini_batches+1)):
                    x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                    y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
                    if x is not None:  # pragma: no branch
                        yield (x, y)
                    if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
                        return
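For testing the change, a standalone sketch of the same shuffle-then-slice logic might help. `iter_minibatches` is a hypothetical helper, not part of data_distributor.py. The optional `rng` argument and the rounding-up of the batch count are assumptions of this sketch, since the behaviour of `_get_number_of_mini_batches` is not shown here.

```python
import math

import numpy as np


def iter_minibatches(x_total, y_total, batch_size, rng=None):
    """Yield shuffled (x, y) mini-batches for one pass over one station's data.

    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng() if rng is None else rng
    p = rng.permutation(len(x_total))  # one shared index ordering
    x_total, y_total = x_total[p], y_total[p]
    # assumption: the last batch may be smaller than batch_size
    num_mini_batches = math.ceil(len(x_total) / batch_size)
    for prev, curr in enumerate(range(1, num_mini_batches + 1)):
        yield (x_total[prev * batch_size:curr * batch_size, ...],
               y_total[prev * batch_size:curr * batch_size, ...])


# Toy data in which the label of each sample equals its single feature.
x = np.arange(10).reshape(10, 1)
y = np.arange(10)
batches = list(iter_minibatches(x, y, batch_size=4, rng=np.random.default_rng(42)))

for xb, yb in batches:
    # inputs and labels stay aligned inside every batch
    assert np.array_equal(xb[:, 0], yb)
```

Passing a seeded generator makes the shuffle repeatable in a unit test, while leaving `rng` unset gives a fresh ordering on every pass.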