From a1cfdf1d42cb42dce571f9acf3caad0284edd0ee Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Tue, 7 Jul 2020 16:53:03 +0200
Subject: [PATCH] implementation of StandardIterator, DataCollection, and
 KerasIterator (including batch preparation)

---
 src/data_handling/iterator.py | 146 ++++++++++++++++++++++++++++++++++
 1 file changed, 146 insertions(+)
 create mode 100644 src/data_handling/iterator.py

diff --git a/src/data_handling/iterator.py b/src/data_handling/iterator.py
new file mode 100644
index 00000000..4cfa459a
--- /dev/null
+++ b/src/data_handling/iterator.py
@@ -0,0 +1,146 @@
+
+from collections.abc import Iterator, Iterable
+import keras
+import numpy as np
+import math
+import os
+import shutil
+import pickle
+
+
+class StandardIterator(Iterator):
+
+    _position: int = None
+
+    def __init__(self, collection):
+        self._collection = collection
+        self._position = 0
+
+    def __next__(self):
+        try:
+            value = self._collection[self._position]
+            self._position += 1
+        except IndexError:
+            raise StopIteration()
+        return value
+
+
+class DataCollection(Iterable):
+
+    def __init__(self, collection):
+        self._collection = collection
+
+    def __iter__(self):
+        return StandardIterator(self._collection)
+
+
+class KerasIterator(keras.utils.Sequence):
+
+    def __init__(self, collection, batch_size, path, shuffle=False):
+        self._collection: DataCollection = collection
+        self._path = os.path.join(path, "%i.pickle")
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.indexes = []
+        self._cleanup_path(path)
+        self._prepare_batches()
+
+    def __len__(self):
+        return len(self.indexes)
+
+    def __getitem__(self, index):
+        return self.__data_generation(self.indexes[index])
+
+    def __data_generation(self, index):
+        file = self._path % index
+        with open(file, "rb") as f:
+            data = pickle.load(f)
+        return data["X"], data["Y"]
+
+    @staticmethod
+    def _concatenate(new, old):
+        return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new))
+
+    def _get_batch(self, data_list, b):
+        return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list))
+
+    def _prepare_batches(self):
+        index = 0
+        remaining = None
+        for data in self._collection:
+            X, Y = data.get_X(), data.get_Y()
+            if remaining is not None:
+                # X = np.concatenate((remaining[0], X), axis=0)
+                # Y = np.concatenate((remaining[1], Y), axis=0)
+                X = self._concatenate(X, remaining[0])
+                Y = self._concatenate(Y, remaining[1])
+            # check shape
+            length = X[0].shape[0]
+            batches = self._get_number_of_mini_batches(length)
+            for b in range(batches):
+                # batch_X = X[b * self.batch_size:(b+1) * self.batch_size, ...]
+                # batch_Y = Y[b * self.batch_size:(b+1) * self.batch_size, ...]
+                batch_X = self._get_batch(X, b)
+                batch_Y = self._get_batch(Y, b)
+                self._save_to_pickle(X=batch_X, Y=batch_Y, index=index)
+                index += 1
+            if (batches * self.batch_size) < length:
+                remaining = (self._get_batch(X, batches), self._get_batch(Y, batches))
+            else:
+                remaining = None
+        if remaining is not None:
+            self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index)
+            index += 1
+        self.indexes = np.arange(0, index).tolist()
+
+    def _save_to_pickle(self, X, Y, index):
+        data = {"X": X, "Y": Y}
+        file = self._path % index
+        with open(file, "wb") as f:
+            pickle.dump(data, f)
+
+    def _get_number_of_mini_batches(self, number_of_samples):
+        return math.floor(number_of_samples / self.batch_size)
+
+    @staticmethod
+    def _cleanup_path(path, create_new=True):
+        if os.path.exists(path):
+            shutil.rmtree(path)
+        if create_new is True:
+            os.makedirs(path)
+
+    def on_epoch_end(self):
+        if self.shuffle is True:
+            np.random.shuffle(self.indexes)
+
+
+class DummyData:
+
+    def __init__(self):
+        self.number_of_samples = np.random.randint(100, 150)
+
+    def get_X(self):
+        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
+        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2))  # samples, window, variables
+        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2))  # samples, window, variables
+        return [X1, X2, X3]
+
+    def get_Y(self):
+        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1))  # samples, window, variables
+        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1))  # samples, window, variables
+        return [Y1, Y2]
+
+
+if __name__ == "__main__":
+
+    collection = []
+    for _ in range(3):
+        collection.append(DummyData())
+
+    data_collection = DataCollection(collection=collection)
+
+    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
+    iterator = KerasIterator(data_collection, 1000, path, shuffle=True)
+
+    for data in data_collection:
+        print(data)
\ No newline at end of file
-- 
GitLab
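
Usage note (not part of the patch): KerasIterator subclasses keras.utils.Sequence, so the pickled batches can be fed straight into Keras training. The sketch below is hypothetical, wired to the DummyData shapes above (three inputs, two outputs); only `iterator` comes from the patch's __main__ block, everything else is an assumption for illustration.

    # minimal sketch, assuming standalone keras 2.x and the `iterator` built in __main__ above
    import keras
    from keras.layers import Concatenate, Dense, Flatten, Input, Reshape

    # three inputs matching DummyData.get_X(): (14, 5), (10, 2), (1, 2)
    x1, x2, x3 = Input(shape=(14, 5)), Input(shape=(10, 2)), Input(shape=(1, 2))
    merged = Concatenate()([Flatten()(x1), Flatten()(x2), Flatten()(x3)])
    hidden = Dense(32, activation="relu")(merged)
    # two outputs matching DummyData.get_Y(): (5, 1) each
    y1 = Reshape((5, 1))(Dense(5)(hidden))
    y2 = Reshape((5, 1))(Dense(5)(hidden))
    model = keras.Model(inputs=[x1, x2, x3], outputs=[y1, y2])
    model.compile(optimizer="adam", loss="mse")

    # a Sequence reports its own length via __len__, so no steps_per_epoch is needed;
    # on_epoch_end() reshuffles the batch order between epochs when shuffle=True
    model.fit_generator(iterator, epochs=2)

Design-wise, persisting each mini-batch as its own pickle keeps memory usage flat regardless of collection size: __getitem__ only ever loads one batch file, and shuffling permutes file indexes rather than individual samples.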