Test data streaming with multiprocessing
In this branch, the suggested approach of parallelized data streaming by Stefan Kesselheim is tested.
The approach is based on the package multiprocessing
which is used to draw samples in a thread pool.
The distributed sampling is enveloped by `.batch(nworkers)`
and `.unbatch()`
to efficiently spread the creation of the desired mini-batch across the workers.
The raw draft from Stefan's e-mail is as follows:
import multiprocessing
import multiprocessing.pool
import time

import numpy as np
import tensorflow as tf
class MyTest:
    """Toy dataset of 1000 small arrays with an artificially slow item fetch.

    ``__getitem__`` sleeps 20 ms per item to emulate I/O latency;
    ``getitems`` fans the lookups out over a thread pool so the sleeps
    overlap instead of accumulating.
    """

    def __init__(self, thread_pool_size=4):
        """Build the in-memory data and the worker pool.

        Parameters
        ----------
        thread_pool_size : int
            Number of threads used by ``getitems`` to fetch items in parallel.
        """
        # 1000 constant 2x2 float32 arrays: item i is filled with the value i.
        self.a = [np.ones((2, 2), dtype=np.float32) * i for i in range(1000)]
        # A ThreadPool (not a process Pool) suffices here: time.sleep releases
        # the GIL, so threads can overlap the simulated I/O latency.
        self.pool = multiprocessing.pool.ThreadPool(thread_pool_size)

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.a)

    def __getitem__(self, i):
        """Return item ``i`` after a 20 ms delay emulating slow I/O."""
        time.sleep(0.02)
        return self.a[i]

    def getitems(self, indices):
        """Fetch several items in parallel and stack them into one array.

        Parameters
        ----------
        indices : iterable of int
            Item indices to fetch.

        Returns
        -------
        numpy.ndarray
            Array of shape ``(len(indices), 2, 2)``, dtype float32.
        """
        return np.array(self.pool.map(self.__getitem__, indices))

    def close(self):
        """Shut down the worker pool; the instance is unusable afterwards.

        The original draft never closed the pool, leaving its worker threads
        to be reaped only at interpreter shutdown.
        """
        self.pool.close()
        self.pool.join()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
# --- Benchmark 1: parallel fetch of 100 items via the thread pool. ----------
# The e-mail draft used IPython's %%timeit cell magic, which is not valid in
# a plain Python script; time a single pass with time.perf_counter() instead.
t = MyTest(10)

start = time.perf_counter()
a = t.getitems(list(range(100)))
print(f"pooled getitems(100 items): {time.perf_counter() - start:.3f}s")
# --- Benchmark 2: per-item tf.data pipeline (no batching). ------------------
# Each dataset element triggers one numpy_function call fetching a single
# item, so the 20 ms sleeps presumably run one after another with no overlap.
def tf_fun(i):
    """Wrap the slow single-item fetch as a tf.data-compatible op."""
    return tf.numpy_function(t.__getitem__, [i], tf.float32)

start = time.perf_counter()
ds = tf.data.Dataset.range(100).map(tf_fun)
for x in ds:
    pass
print(f"sequential tf.data map: {time.perf_counter() - start:.3f}s")
# --- Benchmark 3: batched tf.data pipeline using the pooled fetch. ----------
# .batch(10) hands tf_fun2 ten indices at a time, so the thread pool can
# overlap the per-item sleeps; .unbatch() restores single-sample elements.
def tf_fun2(i):
    """Wrap the parallel multi-item fetch as a tf.data-compatible op."""
    return tf.numpy_function(t.getitems, [i], tf.float32)

# Sanity check: the batched fetch accepts an eager tensor of indices.
inp = tf.Variable(range(2, 12))
tf_fun2(inp)

start = time.perf_counter()
ds = tf.data.Dataset.range(100).batch(10).map(tf_fun2).unbatch()
for x in ds:
    pass
print(f"batched tf.data map+unbatch: {time.perf_counter() - start:.3f}s")