From a1cfdf1d42cb42dce571f9acf3caad0284edd0ee Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 7 Jul 2020 16:53:03 +0200
Subject: [PATCH] Implement StandardIterator, DataCollection, and KerasIterator
 (including batch preparation)

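The KerasIterator walks a DataCollection, slices the X and Y arrays of
each element into fixed-size mini batches, and pickles every batch to
disk, so training can stream batches from file instead of holding the
whole data set in memory. Samples left over from one element are
concatenated with the next element's data; a trailing partial batch is
written as-is.

A minimal usage sketch (DummyData stands in for any data object that
exposes get_X() and get_Y(); the fit_generator call is only a
hypothetical consumer and not part of this patch):

    collection = DataCollection([DummyData() for _ in range(3)])
    iterator = KerasIterator(collection, batch_size=1000, path="testdata")
    X, Y = iterator[0]  # one batch: lists of numpy arrays
    # model.fit_generator(iterator, epochs=10)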
---
 src/data_handling/iterator.py | 154 ++++++++++++++++++++++++++++++++++
 1 file changed, 154 insertions(+)
 create mode 100644 src/data_handling/iterator.py

diff --git a/src/data_handling/iterator.py b/src/data_handling/iterator.py
new file mode 100644
index 00000000..4cfa459a
--- /dev/null
+++ b/src/data_handling/iterator.py
@@ -0,0 +1,154 @@
+"""Iterator/DataCollection interfaces and a KerasIterator that prepares batches on disk."""
+
+from collections.abc import Iterator, Iterable
+import keras
+import numpy as np
+import os
+import shutil
+import pickle
+
+
+class StandardIterator(Iterator):
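+    """Iterate a list-like collection by position until an IndexError ends the iteration."""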
+
+    _position: int = 0
+
+    def __init__(self, collection):
+        self._collection = collection
+        self._position = 0
+
+    def __next__(self):
+        try:
+            value = self._collection[self._position]
+            self._position += 1
+        except IndexError:
+            raise StopIteration()
+        return value
+
+
+class DataCollection(Iterable):
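+    """Iterable wrapper around a collection; each __iter__ call returns a fresh StandardIterator."""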
+
+    def __init__(self, collection):
+        self._collection = collection
+
+    def __iter__(self):
+        return StandardIterator(self._collection)
+
+
+class KerasIterator(keras.utils.Sequence):
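+    """keras Sequence that pickles all mini batches to disk on creation and loads them lazily by index."""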
+
+    def __init__(self, collection, batch_size, path, shuffle=False):
+        self._collection: DataCollection = collection
+        self._path = os.path.join(path, "%i.pickle")
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.indexes = []
+        self._cleanup_path(path)
+        self._prepare_batches()
+
+    def __len__(self):
+        return len(self.indexes)
+
+    def __getitem__(self, index):
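+        """Return one batch (X, Y) loaded from its pickle file."""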
+        return self.__data_generation(self.indexes[index])
+
+    def __data_generation(self, index):
+        file = self._path % index
+        with open(file, "rb") as f:
+            data = pickle.load(f)
+        return data["X"], data["Y"]
+
+    @staticmethod
+    def _concatenate(new, old):
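+        # concatenate each pair of arrays along the sample axis, keeping the old samples first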
+        return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new))
+
+    def _get_batch(self, data_list, b):
+        return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list))
+
+    def _prepare_batches(self):
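+        """Slice all collection elements into mini batches and pickle each batch to disk.
+
+        Leftover samples that do not fill a complete batch are concatenated with
+        the next element's data; a trailing partial batch is saved as-is.
+        """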
+        index = 0
+        remaining = None
+        for data in self._collection:
+            X, Y = data.get_X(), data.get_Y()
+            if remaining is not None:
+                X = self._concatenate(X, remaining[0])
+                Y = self._concatenate(Y, remaining[1])
+            # all arrays share the first (sample) dimension; take the length from the first X array
+            length = X[0].shape[0]
+            batches = self._get_number_of_mini_batches(length)
+            for b in range(batches):
+                batch_X = self._get_batch(X, b)
+                batch_Y = self._get_batch(Y, b)
+                self._save_to_pickle(X=batch_X, Y=batch_Y, index=index)
+                index += 1
+            if (batches * self.batch_size) < length:
+                remaining = (self._get_batch(X, batches), self._get_batch(Y, batches))
+            else:
+                remaining = None
+        if remaining is not None:
+            self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index)
+            index += 1
+        self.indexes = np.arange(0, index).tolist()
+
+    def _save_to_pickle(self, X, Y, index):
+        data = {"X": X, "Y": Y}
+        file = self._path % index
+        with open(file, "wb") as f:
+            pickle.dump(data, f)
+
+    def _get_number_of_mini_batches(self, number_of_samples):
+        return number_of_samples // self.batch_size
+
+    @staticmethod
+    def _cleanup_path(path, create_new=True):
+        if os.path.exists(path):
+            shutil.rmtree(path)
+        if create_new:
+            os.makedirs(path)
+
+    def on_epoch_end(self):
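+        """Shuffle the batch order if requested; keras calls this hook after every epoch."""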
+        if self.shuffle:
+            np.random.shuffle(self.indexes)
+
+
+class DummyData:
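+    """Toy data object with random X and Y arrays for the smoke test below."""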
+
+    def __init__(self):
+        self.number_of_samples = np.random.randint(100, 150)
+
+    def get_X(self):
+        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
+        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2))  # samples, window, variables
+        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2))  # samples, window, variables
+        return [X1, X2, X3]
+
+    def get_Y(self):
+        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1))  # samples, window, variables
+        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1))  # samples, window, variables
+        return [Y1, Y2]
+
+
+if __name__ == "__main__":
+
+    collection = [DummyData() for _ in range(3)]
+
+    data_collection = DataCollection(collection=collection)
+
+    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+    iterator = KerasIterator(data_collection, 1000, path, shuffle=True)
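+    # each __getitem__ call loads one pickled batch from disk
+    print("prepared %i batches in %s" % (len(iterator), path))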
+
+    for data in data_collection:
+        print(data)
-- 
GitLab