Resolve "Implement Iterator"
 
__author__ = 'Lukas Leufen'
__date__ = '2020-07-07'

from collections.abc import Iterator, Iterable  # import from collections.abc; the plain collections alias was removed in Python 3.10

import math
import os
import pickle
import shutil
from typing import Tuple, List

import keras
import numpy as np
 
 
 
class StandardIterator(Iterator):
    """Iterator over a plain list, returning one element per step."""

    _position: int = None

    def __init__(self, collection: list):
        assert isinstance(collection, list)
        self._collection = collection
        self._position = 0

    def __next__(self):
        """Return next element or stop iteration."""
        try:
            value = self._collection[self._position]
            self._position += 1
        except IndexError:
            raise StopIteration()
        return value
 
 
 
class DataCollection(Iterable):
    """Iterable wrapper around a list, handing out a StandardIterator on each iteration."""

    def __init__(self, collection: list):
        assert isinstance(collection, list)
        self._collection = collection

    def __iter__(self) -> Iterator:
        return StandardIterator(self._collection)
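
# Minimal usage sketch (illustrative, not part of the original change): because
# __iter__ returns a fresh StandardIterator each time, the same collection can be
# traversed repeatedly:
#
#     dc = DataCollection([1, 2, 3])
#     assert list(dc) == [1, 2, 3]
#     assert list(dc) == [1, 2, 3]  # second pass works; iteration state lives in the iterator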
 
 
 
class KerasIterator(keras.utils.Sequence):
    """Keras data sequence that serves pickled batches from disk.

    Implements the keras.utils.Sequence contract (__len__ and __getitem__), so instances
    can be passed to model training directly.
    """

    def __init__(self, collection: DataCollection, batch_size: int, path: str, shuffle: bool = False):
        self._collection = collection
        self._path = os.path.join(path, "%i.pickle")
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes: list = []
        self._cleanup_path(path)
        self._prepare_batches()
 
 
    def __len__(self) -> int:
        return len(self.indexes)

    def __getitem__(self, index: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Get batch for given index."""
        return self.__data_generation(self.indexes[index])

    def __data_generation(self, index: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """Load pickle data from disk."""
        file = self._path % index
        with open(file, "rb") as f:
            data = pickle.load(f)
        return data["X"], data["Y"]
 
 
    @staticmethod
    def _concatenate(new: List[np.ndarray], old: List[np.ndarray]) -> List[np.ndarray]:
        """Concatenate two lists of data along axis=0."""
        return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new))

    def _get_batch(self, data_list: List[np.ndarray], b: int) -> List[np.ndarray]:
        """Get batch according to batch size from data list."""
        return list(map(lambda data: data[b * self.batch_size:(b + 1) * self.batch_size, ...], data_list))
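
    # Example (illustrative): with batch_size = 25 and b = 2, every array in data_list is
    # sliced to data[50:75, ...]. NumPy slicing past the array end simply returns the
    # remaining rows, which _prepare_batches relies on when extracting the final partial batch.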
 
 
    def _prepare_batches(self) -> None:
        """
        Prepare all batches as locally stored files.

        Walk through all elements of the collection and split (or merge) the data according to the batch size.
        Elements that are too long are divided into multiple batches; batches that are not completely filled are
        merged with data from the next collection element. Any data remaining after the last element is saved as a
        smaller batch. Batches are enumerated starting from 0, and the list of all batch numbers is stored in the
        class attribute `indexes`.
        """
        index = 0
        remaining = None
        for data in self._collection:
            X, Y = data.get_X(), data.get_Y()
            if remaining is not None:
                X, Y = self._concatenate(X, remaining[0]), self._concatenate(Y, remaining[1])
            length = X[0].shape[0]
            batches = self._get_number_of_mini_batches(length)
            for b in range(batches):
                batch_X, batch_Y = self._get_batch(X, b), self._get_batch(Y, b)
                self._save_to_pickle(X=batch_X, Y=batch_Y, index=index)
                index += 1
            if (batches * self.batch_size) < length:  # keep remaining to concatenate with next data element
                remaining = (self._get_batch(X, batches), self._get_batch(Y, batches))
            else:
                remaining = None
        if remaining is not None:  # add remaining as smaller batch
            self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index)
            index += 1
        self.indexes = np.arange(0, index).tolist()
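
    # Worked example (illustrative, numbers not from the original file): with batch_size = 25
    # and three collection elements of 60 samples each, element 1 yields batches 0-1 plus a
    # 10-sample remainder; that remainder is concatenated with element 2 (70 samples -> batches
    # 2-3, remainder 20), and element 3 (80 samples) yields batches 4-6 with a final 5-sample
    # leftover saved as smaller batch 7, so indexes becomes [0, 1, ..., 7].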
 
 
    def _save_to_pickle(self, X: List[np.ndarray], Y: List[np.ndarray], index: int) -> None:
        """Save data as a pickle file with variables X and Y under the name <index>.pickle."""
        data = {"X": X, "Y": Y}
        file = self._path % index
        with open(file, "wb") as f:
            pickle.dump(data, f)

    def _get_number_of_mini_batches(self, number_of_samples: int) -> int:
        """Return the number of mini batches as the floored ratio of number of samples to batch size."""
        return math.floor(number_of_samples / self.batch_size)
 
 
    @staticmethod
    def _cleanup_path(path: str, create_new: bool = True) -> None:
        """Remove the given path if it exists, then recreate it as an empty directory if enabled."""
        if os.path.exists(path):
            shutil.rmtree(path)
        if create_new is True:
            os.makedirs(path)

    def on_epoch_end(self) -> None:
        """Randomly shuffle indexes if enabled."""
        if self.shuffle is True:
            np.random.shuffle(self.indexes)
 
 
 
class DummyData:  # pragma: no cover
    """Dummy data source with random sample counts and shapes, used for the demo below."""

    def __init__(self, number_of_samples=None):
        # draw the random size here: a call in the default argument would be evaluated only
        # once at class definition time, giving every instance the same sample count
        self.number_of_samples = number_of_samples if number_of_samples is not None else np.random.randint(100, 150)

    def get_X(self):
        X1 = np.random.randint(0, 10, size=(self.number_of_samples, 14, 5))  # samples, window, variables
        X2 = np.random.randint(21, 30, size=(self.number_of_samples, 10, 2))  # samples, window, variables
        X3 = np.random.randint(-5, 0, size=(self.number_of_samples, 1, 2))  # samples, window, variables
        return [X1, X2, X3]

    def get_Y(self):
        Y1 = np.random.randint(0, 10, size=(self.number_of_samples, 5, 1))  # samples, window, variables
        Y2 = np.random.randint(21, 30, size=(self.number_of_samples, 5, 1))  # samples, window, variables
        return [Y1, Y2]
 
 
 
if __name__ == "__main__":

    collection = []
    for _ in range(3):
        collection.append(DummyData(50))

    data_collection = DataCollection(collection=collection)

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    iterator = KerasIterator(data_collection, 25, path, shuffle=True)

    for data in data_collection:
        print(data)
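
    # Hedged extension of the demo (not in the original file): fetch every prepared batch via
    # the Sequence interface and print its shapes, then remove the test data directory again.
    for b in range(len(iterator)):
        batch_X, batch_Y = iterator[b]
        print("batch %i: X %s, Y %s" % (b, [x.shape for x in batch_X], [y.shape for y in batch_Y]))
    KerasIterator._cleanup_path(path, create_new=False)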
 