__author__ = 'Lukas Leufen'
__date__ = '2020-07-07'

# FIX(review): importing Iterator/Iterable directly from `collections` is
# deprecated since Python 3.3 and removed in Python 3.10 — import from
# `collections.abc` instead.
from collections.abc import Iterator, Iterable
import keras
import numpy as np
import math
import os
import shutil
import pickle
from typing import Tuple, List


class StandardIterator(Iterator):
    """Iterator that walks exactly once over a plain list collection."""

    # Class-level default so an uninitialised instance (e.g. created via
    # object.__new__) reports no position; set to 0 in __init__.
    _position: int = None

    def __init__(self, collection: list):
        """Store the collection and reset the cursor to the first element.

        :param collection: list to iterate over (asserted to be a list)
        """
        assert isinstance(collection, list)
        self._collection = collection
        self._position = 0

    def __next__(self):
        """Return next element or stop iteration."""
        try:
            value = self._collection[self._position]
            self._position += 1
        except IndexError:
            # cursor ran past the end of the collection -> iteration is done
            raise StopIteration()
        return value


class DataCollection(Iterable):
    """Iterable wrapper around a list; each iteration starts a fresh StandardIterator."""

    def __init__(self, collection: list):
        """Store the underlying list.

        :param collection: list of data elements (asserted to be a list)
        """
        assert isinstance(collection, list)
        self._collection = collection

    def __iter__(self) -> Iterator:
        """Return a new StandardIterator over the stored collection."""
        return StandardIterator(self._collection)


class KerasIterator(keras.utils.Sequence):
    """Keras Sequence that materialises batches as pickle files on disk and loads them on demand.

    On construction, the given data collection is split/merged into fixed-size
    batches which are written to ``<path>/<index>.pickle``; ``__getitem__``
    then only performs a single pickle load per batch.
    """

    def __init__(self, collection: DataCollection, batch_size: int, path: str, shuffle: bool = False):
        """Prepare all batches on disk.

        :param collection: DataCollection whose elements provide get_X()/get_Y()
        :param batch_size: number of samples per batch
        :param path: directory for the pickle files (recreated from scratch)
        :param shuffle: if True, shuffle batch order on each epoch end
        """
        self._collection = collection
        # template for the per-batch pickle files, filled with the batch index
        self._path = os.path.join(path, "%i.pickle")
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes: list = []
        self._cleanup_path(path)
        self._prepare_batches()

    def __len__(self) -> int:
        """Return the number of prepared batches."""
        return len(self.indexes)

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        """Get batch for given index."""
        return self.__data_generation(self.indexes[index])

    def __data_generation(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        """Load pickle data from disk."""
        file = self._path % index
        with open(file, "rb") as f:
            data = pickle.load(f)
        return data["X"], data["Y"]

    @staticmethod
    def _concatenate(new: List[np.ndarray], old: List[np.ndarray]) -> List[np.ndarray]:
        """Concatenate two lists of data along axis=0 (old data first, then new)."""
        return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new))

    def _get_batch(self, data_list: List[np.ndarray], b: int) -> List[np.ndarray]:
        """Get batch number ``b`` (slice of batch_size samples) from each array in data list."""
        return list(map(lambda data: data[b * self.batch_size:(b + 1) * self.batch_size, ...], data_list))

    def _prepare_batches(self) -> None:
        """
        Prepare all batches as locally stored files.

        Walk through all elements of collection and split (or merge) data according to the batch size. Too long data
        sets are divided into multiple batches. Not fully filled batches are merged with data from the next collection
        element. If data is remaining after the last element, it is saved as smaller batch. All batches are enumerated
        beginning from 0. A list with all batch numbers is stored in class's parameter indexes.
        """
        index = 0
        remaining = None
        for data in self._collection:
            X, Y = data.get_X(), data.get_Y()
            if remaining is not None:
                # prepend the leftover samples from the previous element
                X, Y = self._concatenate(X, remaining[0]), self._concatenate(Y, remaining[1])
            length = X[0].shape[0]
            batches = self._get_number_of_mini_batches(length)
            for b in range(batches):
                batch_X, batch_Y = self._get_batch(X, b), self._get_batch(Y, b)
                self._save_to_pickle(X=batch_X, Y=batch_Y, index=index)
                index += 1
            if (batches * self.batch_size) < length:  # keep remaining to concatenate with next data element
                remaining = (self._get_batch(X, batches), self._get_batch(Y, batches))
            else:
                remaining = None
        if remaining is not None:  # add remaining as smaller batch
            self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index)
            index += 1
        self.indexes = np.arange(0, index).tolist()

    def _save_to_pickle(self, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None:
        """Save data as pickle file with variables X and Y and given index as <index>.pickle ."""
        data = {"X": X, "Y": Y}
        file = self._path % index
        with open(file, "wb") as f:
            pickle.dump(data, f)

    def _get_number_of_mini_batches(self, number_of_samples: int) -> int:
        """Return number of mini batches as the floored ratio of number of samples to batch size."""
        return math.floor(number_of_samples / self.batch_size)

    @staticmethod
    def _cleanup_path(path: str, create_new: bool = True) -> None:
        """First remove existing path, second create empty path if enabled."""
        if os.path.exists(path):
            shutil.rmtree(path)
        if create_new is True:
            os.makedirs(path)

    def on_epoch_end(self) -> None:
        """Randomly shuffle indexes if enabled."""
        if self.shuffle is True:
            np.random.shuffle(self.indexes)


class DummyData:  # pragma: no cover
    """Data stub producing random X/Y arrays; used by the demo script below."""

    # NOTE(review): the random default is evaluated once at definition time, so
    # every no-arg instantiation shares the SAME sample count — confirm intent.
    def __init__(self, number_of_samples=np.random.randint(100, 150)):
        self.number_of_samples = number_of_samples

    def get_X(self):
        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2))  # samples, window, variables
        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2))  # samples, window, variables
        return [X1, X2, X3]

    def get_Y(self):
        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1))  # samples, window, variables
        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1))  # samples, window, variables
        return [Y1, Y2]


if __name__ == "__main__":

    collection = []
    for _ in range(3):
        collection.append(DummyData(50))

    data_collection = DataCollection(collection=collection)

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    iterator = KerasIterator(data_collection, 25, path, shuffle=True)

    for data in data_collection:
        print(data)
from src.data_handling.iterator import DataCollection, StandardIterator, KerasIterator
from src.helpers.testing import PyTestAllEqual

import numpy as np
import pytest
import os
import shutil


class TestStandardIterator:

    @pytest.fixture
    def collection(self):
        return list(range(10))

    def test_blank(self):
        # a bare instance (no __init__) must expose the class-level default
        std_iterator = object.__new__(StandardIterator)
        assert std_iterator._position is None

    def test_init(self, collection):
        std_iterator = StandardIterator(collection)
        assert std_iterator._collection == list(range(10))
        assert std_iterator._position == 0

    def test_next(self, collection):
        # manual exhaustion raises StopIteration after the last element
        std_iterator = StandardIterator(collection)
        for expected in range(10):
            assert expected == next(std_iterator)
        with pytest.raises(StopIteration):
            next(std_iterator)
        # a fresh iterator also works inside a for loop
        std_iterator = StandardIterator(collection)
        for position, value in enumerate(iter(std_iterator)):
            assert value == position


class TestDataCollection:

    @pytest.fixture
    def collection(self):
        return list(range(10))

    def test_init(self, collection):
        data_collection = DataCollection(collection)
        assert data_collection._collection == collection

    def test_iter(self, collection):
        data_collection = DataCollection(collection)
        assert isinstance(iter(data_collection), StandardIterator)
        for position, value in enumerate(data_collection):
            assert value == position


class DummyData:
    """Random-data stub mirroring the data interface KerasIterator consumes."""

    def __init__(self, number_of_samples=np.random.randint(100, 150)):
        self.number_of_samples = number_of_samples

    def get_X(self):
        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2))  # samples, window, variables
        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2))  # samples, window, variables
        return [X1, X2, X3]

    def get_Y(self):
        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1))  # samples, window, variables
        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1))  # samples, window, variables
        return [Y1, Y2]


class TestKerasIterator:

    @pytest.fixture
    def collection(self):
        # three elements of slightly different length (50, 51, 52 samples)
        elements = [DummyData(50 + i) for i in range(3)]
        return DataCollection(collection=elements)

    @pytest.fixture
    def path(self):
        p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
        if os.path.exists(p):
            shutil.rmtree(p, ignore_errors=True)
        yield p
        shutil.rmtree(p, ignore_errors=True)

    def test_init(self, collection, path):
        iterator = KerasIterator(collection, 25, path)
        assert isinstance(iterator._collection, DataCollection)
        assert iterator._path == os.path.join(path, "%i.pickle")
        assert iterator.batch_size == 25
        assert iterator.shuffle is False

    def test_cleanup_path(self, path):
        assert os.path.exists(path) is False
        iterator = object.__new__(KerasIterator)
        # removing a non-existing path without recreation is a no-op
        iterator._cleanup_path(path, create_new=False)
        assert os.path.exists(path) is False
        iterator._cleanup_path(path)
        assert os.path.exists(path) is True
        iterator._cleanup_path(path, create_new=False)
        assert os.path.exists(path) is False

    def test_get_number_of_mini_batches(self):
        iterator = object.__new__(KerasIterator)
        iterator.batch_size = 36
        assert iterator._get_number_of_mini_batches(30) == 0
        assert iterator._get_number_of_mini_batches(40) == 1
        assert iterator._get_number_of_mini_batches(72) == 2

    def test_len(self):
        iterator = object.__new__(KerasIterator)
        iterator.indexes = [0, 1, 2, 3, 4, 5]
        assert len(iterator) == 6

    def test_concatenate(self):
        arr1 = DummyData(10).get_X()
        arr2 = DummyData(50).get_X()
        iterator = object.__new__(KerasIterator)
        new_arr = iterator._concatenate(arr2, arr1)
        # expected: old data (arr1) first, new data (arr2) appended
        test_arr = [np.concatenate((arr1[0], arr2[0]), axis=0),
                    np.concatenate((arr1[1], arr2[1]), axis=0),
                    np.concatenate((arr1[2], arr2[2]), axis=0)]
        for i in range(3):
            assert PyTestAllEqual([new_arr[i], test_arr[i]])

    def test_get_batch(self):
        arr = DummyData(20).get_X()
        iterator = object.__new__(KerasIterator)
        iterator.batch_size = 19
        # first batch is full (19 samples), second holds the single leftover
        batch1 = iterator._get_batch(arr, 0)
        assert batch1[0].shape[0] == 19
        batch2 = iterator._get_batch(arr, 1)
        assert batch2[0].shape[0] == 1

    def test_save_to_pickle(self, path):
        os.makedirs(path)
        d = DummyData(20)
        X, Y = d.get_X(), d.get_Y()
        iterator = object.__new__(KerasIterator)
        iterator._path = os.path.join(path, "%i.pickle")
        assert os.path.exists(iterator._path % 2) is False
        iterator._save_to_pickle(X=X, Y=Y, index=2)
        assert os.path.exists(iterator._path % 2) is True

    def test_prepare_batches(self, collection, path):
        # 50+51+52 samples at batch size 50 -> 3 full batches + 1 remainder
        iterator = object.__new__(KerasIterator)
        iterator._collection = collection
        iterator.batch_size = 50
        iterator.indexes = []
        iterator._path = os.path.join(path, "%i.pickle")
        os.makedirs(path)
        iterator._prepare_batches()
        assert len(os.listdir(path)) == 4
        assert len(iterator.indexes) == 4
        assert len(iterator) == 4
        assert iterator.indexes == [0, 1, 2, 3]

    def test_prepare_batches_no_remaining(self, path):
        # exactly one full batch, nothing left over
        iterator = object.__new__(KerasIterator)
        iterator._collection = DataCollection([DummyData(50)])
        iterator.batch_size = 50
        iterator.indexes = []
        iterator._path = os.path.join(path, "%i.pickle")
        os.makedirs(path)
        iterator._prepare_batches()
        assert len(os.listdir(path)) == 1
        assert len(iterator.indexes) == 1
        assert len(iterator) == 1
        assert iterator.indexes == [0]

    def test_data_generation(self, collection, path):
        iterator = KerasIterator(collection, 50, path)
        X, Y = iterator._KerasIterator__data_generation(0)
        expected = next(iter(collection))
        assert PyTestAllEqual([X, expected.get_X()])
        assert PyTestAllEqual([Y, expected.get_Y()])

    def test_getitem(self, collection, path):
        iterator = KerasIterator(collection, 50, path)
        X, Y = iterator[0]
        expected = next(iter(collection))
        assert PyTestAllEqual([X, expected.get_X()])
        assert PyTestAllEqual([Y, expected.get_Y()])
        # NOTE(review): reversed() returns an iterator and does NOT modify
        # iterator.indexes in place — this line is a no-op as written; if the
        # intent was to reverse the index order, it should be
        # `iterator.indexes.reverse()`. Kept unchanged — confirm intent.
        reversed(iterator.indexes)
        X, Y = iterator[3]
        assert PyTestAllEqual([X, expected.get_X()])
        assert PyTestAllEqual([Y, expected.get_Y()])

    def test_on_epoch_end(self):
        iterator = object.__new__(KerasIterator)
        iterator.indexes = [0, 1, 2, 3, 4]
        iterator.shuffle = False
        iterator.on_epoch_end()
        assert iterator.indexes == [0, 1, 2, 3, 4]
        iterator.shuffle = True
        # retry until the shuffle actually changes the order
        while iterator.indexes == sorted(iterator.indexes):
            iterator.on_epoch_end()
        assert iterator.indexes != [0, 1, 2, 3, 4]
        assert sorted(iterator.indexes) == [0, 1, 2, 3, 4]