Permute data for mini-batches

Duplicate of #47 (closed)

The data of one station should be permuted before creating mini-batches. As a best practice, each mini-batch should be a representative sample of the full data distribution. Non-permuted data would lead to "seasonal batches", for example batches consisting only of DJF or only of MAM samples.
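To illustrate the "seasonal batches" problem, here is a minimal sketch on toy data. The season labels, the batch size of 90 and the helper seasons_per_batch are made up for illustration and are not part of the project code.

```python
import numpy as np

# Hypothetical toy data: 360 time-ordered samples, 90 per season.
seasons = np.repeat(["DJF", "MAM", "JJA", "SON"], 90)
batch_size = 90


def seasons_per_batch(labels, batch_size):
    """Count the distinct seasons in each consecutive mini-batch."""
    return [
        len(np.unique(labels[start:start + batch_size]))
        for start in range(0, len(labels), batch_size)
    ]


# Time-ordered data: every batch holds exactly one season.
ordered_counts = seasons_per_batch(seasons, batch_size)

# Permuted data: the seasons are mixed within each batch.
rng = np.random.default_rng(0)  # fixed seed only to make the sketch reproducible
shuffled = seasons[rng.permutation(len(seasons))]
shuffled_counts = seasons_per_batch(shuffled, batch_size)

print(ordered_counts)   # [1, 1, 1, 1]
print(shuffled_counts)  # each batch now covers several seasons
```

Without the permutation, each gradient update would see a single season only; after it, every batch is much closer to a sample of the whole year.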

  • create a new index ordering using np.random.permutation
  • permute the inputs (x_total) using the new indices
  • permute the labels (y_total) using the same indices as for x_total
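The three steps above can be sketched on toy arrays as follows. The shapes and values are assumptions chosen so that the pairing between inputs and labels is easy to check. The sketch uses np.random.default_rng for a reproducible seed; np.random.permutation would work the same way.

```python
import numpy as np

rng = np.random.default_rng(42)  # seeded only for reproducibility

# Hypothetical toy data: 10 samples with 3 features each.
# The label of sample i equals the first feature of row i.
x_total = np.arange(30).reshape(10, 3)
y_total = np.arange(10) * 3

# 1) create a new index ordering
p = rng.permutation(len(x_total))
# 2) permute the inputs along the sample axis
x_shuffled = x_total[p]
# 3) permute the labels with the same indices
y_shuffled = y_total[p]

# Each label still belongs to its original input row.
print(np.array_equal(x_shuffled[:, 0], y_shuffled))  # True
```

Using one shared index array keeps inputs and labels aligned; shuffling x_total and y_total independently would break the pairing.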

The permutation should happen directly before the mini-batch loop in distribute_on_batches (data_distributor.py).

The method could look like this:

def distribute_on_batches(self, fit_call=True):
    while True:
        for k, v in enumerate(self.generator):
            # get rank of output
            mod_rank = self._get_model_rank()
            # get number of mini batches
            num_mini_batches = self._get_number_of_mini_batches(v)
            x_total = np.copy(v[0])
            y_total = np.copy(v[1])
            ### BEGIN NEW CODE ###
            # shuffle inputs and labels with the same index ordering
            p = np.random.permutation(len(x_total))
            x_total = x_total[p]
            y_total = y_total[p]
            ### END NEW CODE ###
            for prev, curr in enumerate(range(1, num_mini_batches + 1)):
                x = x_total[prev * self.batch_size:curr * self.batch_size, ...]
                y = [y_total[prev * self.batch_size:curr * self.batch_size, ...] for _ in range(mod_rank)]
                if x is not None:  # pragma: no branch
                    yield (x, y)
                    if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
                        return