Skip to content
Snippets Groups Projects
Commit a1cfdf1d authored by lukas leufen's avatar lukas leufen
Browse files

implementation of StandardIterator, DataCollection, and KerasIterator (including batch preparation)

parent 786c530e
No related branches found
No related tags found
4 merge requests!136update release branch,!135Release v0.11.0,!134MLAir is decoupled from join,!117Resolve "Implement Iterator"
Pipeline #40602 passed
from collections import Iterator, Iterable
import keras
import numpy as np
import math
import os
import shutil
import pickle
class StandardIterator(Iterator):
    """Sequential iterator over any index-accessible collection.

    Walks the wrapped collection from position 0 upwards and stops as soon
    as indexing past the end raises IndexError.
    """

    _position: int = None

    def __init__(self, collection):
        """Store the collection and start at position 0."""
        self._collection = collection
        self._position = 0

    def __next__(self):
        """Return the element at the current position and advance by one.

        :raises StopIteration: when the underlying collection is exhausted.
        """
        try:
            item = self._collection[self._position]
        except IndexError:
            raise StopIteration()
        self._position += 1
        return item
class DataCollection(Iterable):
    """Iterable wrapper around a collection of data elements.

    Each call to ``iter()`` hands out a fresh, independent
    :class:`StandardIterator`, so the collection can be traversed
    multiple times (e.g. once per training epoch).
    """

    def __init__(self, collection):
        """Keep a reference to the wrapped collection."""
        self._collection = collection

    def __iter__(self):
        """Return a new StandardIterator positioned at the start."""
        return StandardIterator(self._collection)
class KerasIterator(keras.utils.Sequence):
    """Keras Sequence serving pre-computed mini batches from pickle files.

    On construction, all data elements of the given collection are cut into
    batches of ``batch_size`` samples and written to ``<path>/<index>.pickle``.
    ``__getitem__`` then only needs to load a single file per batch, which keeps
    memory usage flat during training.
    """

    def __init__(self, collection, batch_size, path, shuffle=False):
        """Prepare all batches on disk.

        :param collection: iterable of data elements providing get_X()/get_Y()
        :param batch_size: number of samples per batch
        :param path: directory for the pickle files (recreated empty!)
        :param shuffle: shuffle batch order after each epoch if True
        """
        self._collection: DataCollection = collection
        self._path = os.path.join(path, "%i.pickle")  # file name template, filled with the batch index
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = []
        self._cleanup_path(path)
        self._prepare_batches()

    def __len__(self):
        """Return the total number of batches."""
        return len(self.indexes)

    def __getitem__(self, index):
        """Load and return batch number ``index`` as an (X, Y) tuple."""
        return self.__data_generation(self.indexes[index])

    def __data_generation(self, index):
        """Load a single stored batch from its pickle file."""
        file = self._path % index
        with open(file, "rb") as f:
            data = pickle.load(f)
        return data["X"], data["Y"]

    @staticmethod
    def _concatenate(new, old):
        """Concatenate two lists of arrays element-wise along the sample axis (old first)."""
        return list(map(lambda n1, n2: np.concatenate((n1, n2), axis=0), old, new))

    def _get_batch(self, data_list, b):
        """Slice batch number ``b`` (batch_size samples) out of every array in data_list."""
        return list(map(lambda data: data[b * self.batch_size:(b+1) * self.batch_size, ...], data_list))

    def _prepare_batches(self):
        """Split the whole collection into batches and pickle them to disk.

        A leftover smaller than batch_size is carried over and prepended to the
        next data element; only the very last leftover is stored as a (smaller)
        batch of its own. Afterwards ``self.indexes`` holds one entry per file.
        """
        index = 0
        remaining = None
        for data in self._collection:
            X, Y = data.get_X(), data.get_Y()
            if remaining is not None:
                X = self._concatenate(X, remaining[0])
                Y = self._concatenate(Y, remaining[1])
            # number of samples; assumes all arrays share the same axis-0 length — TODO confirm upstream
            length = X[0].shape[0]
            batches = self._get_number_of_mini_batches(length)
            for b in range(batches):
                batch_X = self._get_batch(X, b)
                batch_Y = self._get_batch(Y, b)
                self._save_to_pickle(X=batch_X, Y=batch_Y, index=index)
                index += 1
            if (batches * self.batch_size) < length:  # keep the incomplete tail for the next element
                remaining = (self._get_batch(X, batches), self._get_batch(Y, batches))
            else:
                remaining = None
        if remaining is not None:  # store the final incomplete batch
            self._save_to_pickle(X=remaining[0], Y=remaining[1], index=index)
            index += 1
        # plain range is enough here; the numpy round trip (np.arange().tolist()) was unnecessary
        self.indexes = list(range(index))

    def _save_to_pickle(self, X, Y, index):
        """Write one batch as ``{"X": X, "Y": Y}`` to its pickle file."""
        data = {"X": X, "Y": Y}
        file = self._path % index
        with open(file, "wb") as f:
            pickle.dump(data, f)

    def _get_number_of_mini_batches(self, number_of_samples):
        """Return how many full batches fit into ``number_of_samples``."""
        # integer floor division instead of math.floor(a / b): exact for any int,
        # whereas the float division could lose precision for very large counts
        return number_of_samples // self.batch_size

    @staticmethod
    def _cleanup_path(path, create_new=True):
        """Remove ``path`` recursively and (if create_new) recreate it empty."""
        if os.path.exists(path):
            shutil.rmtree(path)
        if create_new is True:
            os.makedirs(path)

    def on_epoch_end(self):
        """Shuffle the batch order in place (if enabled) after each epoch."""
        if self.shuffle is True:
            np.random.shuffle(self.indexes)
class DummyData:
    """Random dummy data element used to test the iterator classes."""

    def __init__(self):
        # random number of samples in [100, 150)
        self.number_of_samples = np.random.randint(100, 150)

    def get_X(self):
        """Return three random input branches shaped (n, 14, 5), (n, 10, 2), (n, 1, 2)."""
        specs = [((14, 5), 0, 10), ((10, 2), 21, 30), ((1, 2), -5, 0)]  # (window/vars shape, low, high)
        return [np.random.randint(low, high, size=(self.number_of_samples,) + shape)
                for shape, low, high in specs]

    def get_Y(self):
        """Return two random target branches, each shaped (n, 5, 1)."""
        return [np.random.randint(low, high, size=(self.number_of_samples, 5, 1))
                for low, high in ((0, 10), (21, 30))]
if __name__ == "__main__":
    # build a small collection of random dummy data elements
    dummies = [DummyData() for _ in range(3)]
    data_collection = DataCollection(collection=dummies)
    # batches of size 1000 are written into a "testdata" folder next to this file
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata")
    iterator = KerasIterator(data_collection, 1000, path, shuffle=True)
    for data in data_collection:
        print(data)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment