Permute data for mini-batches
Duplicate of #47 (closed)
The data of a station should be permuted before mini-batches are created. As best practice, each mini-batch should be a representative sample of the full data distribution. Unpermuted data would lead to "seasonal batches", for example batches consisting only of DJF or MAM...
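The effect can be illustrated with a small sketch. The toy data (360 time-ordered samples, 90 per season), the batch size and the helper `seasons_per_batch` are assumptions made only for this illustration; they are not taken from the code base.

```python
import numpy as np


def seasons_per_batch(season_labels, batch_size):
    """Return the set of seasons found in each consecutive mini-batch."""
    return [
        set(season_labels[start:start + batch_size].tolist())
        for start in range(0, len(season_labels), batch_size)
    ]


# Hypothetical toy data: 360 time-ordered samples, 90 per season.
season_labels = np.repeat(["DJF", "MAM", "JJA", "SON"], 90)
batch_size = 90

# Without permutation every mini-batch covers exactly one season.
ordered = seasons_per_batch(season_labels, batch_size)

# With permutation the seasons are mixed within each mini-batch.
rng = np.random.default_rng(0)  # fixed seed only to make the demo repeatable
p = rng.permutation(len(season_labels))
shuffled = seasons_per_batch(season_labels[p], batch_size)

print("ordered: ", [sorted(s) for s in ordered])
print("shuffled:", [sorted(s) for s in shuffled])
```

In the ordered case each batch sees a single season, so every gradient step is computed on a biased sample of the data.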
- Create a new index ordering using np.random.permutation
- Permute the inputs (x_total) using the new indices
- Permute the labels (y_total) using the same indices as for x_total
The permutation should happen directly before the mini-batch loop in distribute_on_batches.
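The three steps can be tried in isolation. This is a minimal sketch on hypothetical toy arrays, built so that each label equals ten times its input; it only shows that indexing both arrays with the same permutation keeps every input paired with its own label.

```python
import numpy as np

# Hypothetical toy arrays: 6 samples with y[i] == 10 * x[i]
x_total = np.arange(6).reshape(6, 1)
y_total = 10 * np.arange(6).reshape(6, 1)

# One index ordering, applied to inputs and labels alike
p = np.random.permutation(len(x_total))
x_shuffled = x_total[p]
y_shuffled = y_total[p]

# The input/label pairs are still intact after shuffling
print(np.array_equal(y_shuffled, 10 * x_shuffled))
```

Indexing with an integer array returns a copy, so the original `x_total` and `y_total` stay in their time order here.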
The method could look like this:
# Requires `import numpy as np` at module level.
def distribute_on_batches(self, fit_call=True):
    while True:
        for k, v in enumerate(self.generator):
            # get rank of output
            mod_rank = self._get_model_rank()
            # get number of mini batches
            num_mini_batches = self._get_number_of_mini_batches(v)
            x_total = np.copy(v[0])
            y_total = np.copy(v[1])
            ### BEGIN NEW CODE ###
            # draw one random index ordering and apply it to inputs and labels alike
            p = np.random.permutation(len(x_total))
            x_total = x_total[p]
            y_total = y_total[p]
            ### END NEW CODE ###
            for prev, curr in enumerate(range(1, num_mini_batches+1)):
                x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
                if x is not None:  # pragma: no branch
                    yield (x, y)
                    if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
                        ...  # the rest of the method is cut off in the original text
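For testing the idea outside the class, the shuffling and slicing can be sketched as a free function. `shuffled_minibatches` and its `rng` parameter are hypothetical and not part of the code base; the ceil division for the number of mini-batches is also an assumption, since `_get_number_of_mini_batches` is not shown here. Passing a seeded `np.random.default_rng` makes the shuffle reproducible, which the global `np.random.permutation` call only is after `np.random.seed`.

```python
import math

import numpy as np


def shuffled_minibatches(x_total, y_total, batch_size, rng=None):
    """Yield (x, y) mini-batches after shuffling both arrays with one permutation.

    Hypothetical helper for illustration, not part of the code base.
    """
    rng = np.random.default_rng() if rng is None else rng
    p = rng.permutation(len(x_total))
    x_total, y_total = x_total[p], y_total[p]
    # Assumption: the last, possibly smaller, batch is kept (ceil division).
    num_mini_batches = math.ceil(len(x_total) / batch_size)
    for i in range(num_mini_batches):
        batch = slice(i * batch_size, (i + 1) * batch_size)
        yield x_total[batch], y_total[batch]


# Toy data: 10 samples with y == 2 * x, split into batches of 4
x = np.arange(10).reshape(10, 1)
y = 2 * x
batches = list(
    shuffled_minibatches(x, y, batch_size=4, rng=np.random.default_rng(42))
)
print([len(bx) for bx, _ in batches])
```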