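"""Serialize a variable-length image dataset to Apache Arrow IPC files.

Running this script writes one Arrow file per split ("train" and "val") into
the given save directory, with one single-row record batch per sample
(PNG-encoded image bytes plus an integer label). The VarlengthArrow class
reads those files back as a PyTorch Dataset.
"""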
import argparse
import io
import os

from PIL import Image
import pyarrow as pa
import torch
import torchvision


class VarlengthArrow(torch.utils.data.Dataset):
    """Dataset backed by an Arrow IPC file of variable-size encoded images.

    File handles are opened lazily in `__getitem__` rather than in
    `__init__`, so each process that reads from the dataset (for example a
    DataLoader worker) gets its own handle. With ``use_mmap=True`` the file
    is memory-mapped instead of read through a regular file handle.
    """

    def __init__(self, data_root, transform=None, use_mmap=False):
        self.data_root = data_root
        self.transform = transform
        self.use_mmap = use_mmap
        self.arrowfile = None
        self.reader = None

        # Open the file once just to count the record batches (one sample
        # per batch), then close it again; the persistent handle is opened
        # lazily on first access.
        if self.use_mmap:
            arrowfile = pa.memory_map(self.data_root, 'r')
        else:
            arrowfile = pa.OSFile(self.data_root, 'rb')
        try:
            with pa.ipc.open_file(arrowfile) as reader:
                self._len = reader.num_record_batches
        finally:
            arrowfile.close()

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        # Allow negative indexing (i.e., from the end).
        if idx < 0:
            idx = self._len + idx
        if idx < 0 or idx >= self._len:
            raise IndexError(f'index {idx} out of bounds')

        # Lazily open the file the first time this process reads from it.
        if self.arrowfile is None:
            if self.use_mmap:
                self.arrowfile = pa.memory_map(self.data_root, 'r')
            else:
                self.arrowfile = pa.OSFile(self.data_root, 'rb')
            self.reader = pa.ipc.open_file(self.arrowfile)

        # Each record batch holds exactly one sample.
        row = self.reader.get_batch(idx)
        img_bytes = row['image_data'][0].as_py()
        target = row['label'][0].as_py()

        # Decode the stored image bytes into a PIL image.
        with io.BytesIO(img_bytes) as byte_stream:
            with Image.open(byte_stream) as img:
                img = img.convert("RGB")

        if self.transform:
            img = self.transform(img)
        return img, target

    def __del__(self):
        if self.reader is not None and hasattr(self.reader, 'close'):
            self.reader.close()
        if self.arrowfile is not None:
            self.arrowfile.close()
        super_obj = super()
        if hasattr(super_obj, '__del__'):
            super_obj.__del__()
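

# A minimal usage sketch (not executed by this script): reading one of the
# files written by main() back through a DataLoader. The file name matches
# what main() writes (here assumed to sit in the current directory); the
# transform, batch size, and worker count are illustrative choices, not
# requirements of VarlengthArrow.
#
#   dset = VarlengthArrow(
#       'varlength-data-train.arrow',
#       transform=torchvision.transforms.Compose([
#           torchvision.transforms.Resize((224, 224)),
#           torchvision.transforms.ToTensor(),
#       ]),
#       use_mmap=True,
#   )
#   loader = torch.utils.data.DataLoader(dset, batch_size=32, num_workers=4)
#   images, labels = next(iter(loader))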


def _build_fake_varlength_dataset():
    # Create a dataset consisting of differently sized images.
    dset = torch.utils.data.ConcatDataset([
        torchvision.datasets.FakeData(334, image_size=(1, 16, 16)),
        torchvision.datasets.FakeData(333, image_size=(3, 244, 244)),
        torchvision.datasets.FakeData(333, image_size=(3, 256, 256)),
    ])
    return dset


def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'save_dir',
        help='Where to save the resulting Apache Arrow files.',
    )
    return parser.parse_args()


def main():
    args = _parse_args()
    splits = ["train", "val"]

    # Create the type singletons that our data will use.
    binary_t = pa.binary()
    uint16_t = pa.uint16()

    # How will our dataset be structured? In this case, each element in
    # the dataset will be a dictionary containing binary data under the
    # key "image_data" and an unsigned 16-bit integer under the key
    # "label".
    schema = pa.schema([
        pa.field('image_data', binary_t),
        pa.field('label', uint16_t),
    ])

    os.makedirs(args.save_dir, exist_ok=True)
    for split in splits:
        dset = _build_fake_varlength_dataset()
        sampler = torch.utils.data.RandomSampler(dset)

        with pa.OSFile(
                os.path.join(args.save_dir, f'varlength-data-{split}.arrow'),
                'wb',
        ) as f:
            with pa.ipc.new_file(f, schema) as writer:
                for i in sampler:
                    (pil_img, label) = dset[i]

                    # Re-encode the PIL image as PNG bytes so samples of
                    # different sizes can share one binary column.
                    with io.BytesIO() as bytes_buffer:
                        pil_img.save(bytes_buffer, 'png')
                        bytes_buffer.seek(0)
                        img_bytes = bytes_buffer.read()

                    # Write each sample as its own single-row record batch.
                    image_data = pa.array([img_bytes], type=binary_t)
                    label = pa.array([label], type=uint16_t)
                    batch = pa.record_batch([image_data, label], schema=schema)
                    writer.write(batch)


if __name__ == '__main__':
    main()