# Apache Arrow for variable-length data
In this guide, we present an example of storing arbitrary data with
Python using the efficient and easy-to-use file format [Apache
Arrow](https://arrow.apache.org). We will use fake image
classification data with varying image sizes to create and read from
an Apache Arrow file.

The Apache Arrow API for writing and reading variable-length byte
data in particular offers a great trade-off between usability and
efficiency, even if it takes some getting used to. We can use Apache
Arrow from Python via [PyArrow](https://arrow.apache.org/docs/python/).

While the principles of this guide work with _any_ type of data
because they are based on variable-length byte sequences, reading
data, and especially random-access reading, becomes much faster if
you can specify structured data. Please check the [PyArrow
documentation](https://arrow.apache.org/docs/python/) or the [PyArrow
cookbook](https://arrow.apache.org/cookbook/py/index.html) to find
out more.
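
For orientation, the schema used by the script later in this guide
stores each sample as one variable-length binary column (the encoded
image bytes) plus one small integer label. A minimal sketch of that
schema, with a hypothetical fixed-size alternative shown only as a
comment, could look like this:

```python
import pyarrow as pa

# Variable-length approach used in this guide: any byte sequence fits,
# e.g. a PNG-encoded image of arbitrary size.
varlength_schema = pa.schema([
    pa.field('image_data', pa.binary()),
    pa.field('label', pa.uint16()),
])

# Hypothetical alternative: if every sample had the same shape, a
# fixed-size column of raw pixel values could be used instead, which
# allows faster (random-access) reading.
# structured_schema = pa.schema([
#     pa.field('pixels', pa.list_(pa.uint8(), 3 * 256 * 256)),
#     pa.field('label', pa.uint16()),
# ])
```
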
## Setup
There are two options for setting up the required packages for this guide:

1. Either create a new Python `venv` using the included
   `requirements.txt`, like:

   ```shell
   python3 -m venv ./env
   source env/bin/activate
   python -m pip install -U pip
   python -m pip install -r requirements.txt
   ```

2. Or instead, use the provided modules from the module system. Note
   that the most recent software stage does not have the Arrow module,
   so we use an older software stage for this guide:

   ```shell
   source modules.sh
   ```
## Data creation
To create a train and validation split of fake data in the directory
`./data`, execute the following:
```shell
source env/bin/activate
python varlength_arrow.py ./data
```
This will create two files of roughly 120 MB each, both containing
1000 fake image samples.
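
If you want to double-check the result, a small sketch along these
lines (assuming the `./data` directory from the command above) reads
back the schema and the number of stored samples with PyArrow:

```python
import pyarrow as pa

with pa.OSFile('./data/varlength-data-train.arrow', 'rb') as f:
    with pa.ipc.open_file(f) as reader:
        print(reader.schema)              # image_data: binary, label: uint16
        print(reader.num_record_batches)  # 1000; one sample per record batch
```
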
## Data reading
To use the example PyTorch
[`Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)
implementation for reading, try the following:
```shell
source env/bin/activate
python
>>> import varlength_arrow as va
>>> dset = va.VarlengthArrow('./data/varlength-data-train.arrow')
>>> # Now you can use `dset` like a normal image classification dataset.
>>> pil_img, label = next(iter(dset))
>>> for (pil_img, label) in dset:
...     print(f'{pil_img.size = }, {label = }')
```
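
Since the decoded images have different sizes, they cannot be stacked
into one batch tensor as-is. One way around this, sketched below, is
to resize everything to a common shape and hand the dataset to a
regular PyTorch `DataLoader`; the resize transform, batch size, and
worker count here are illustrative assumptions, not part of the
original example:

```python
import torch
import torchvision.transforms as T

import varlength_arrow as va

# Resizing to a fixed shape lets the default collate function stack the batch.
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

dset = va.VarlengthArrow(
    './data/varlength-data-train.arrow',
    transform=transform,
    use_mmap=True,  # memory-map the Arrow file instead of buffered reads
)
loader = torch.utils.data.DataLoader(dset, batch_size=32, num_workers=4)

for images, labels in loader:
    print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 224, 224])
    break
```
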
#!/usr/bin/env sh
# modules.sh: load the modules needed for this guide. The most recent
# software stage does not provide the Arrow module, so an older stage
# is used instead.
module load Stages/2023
module load GCC OpenMPI PyTorch torchvision Arrow
"""varlength_arrow.py: create and read fake variable-length image
classification data stored in Apache Arrow files.
"""

import argparse
import io
import os

from PIL import Image
import pyarrow as pa
import torch
import torchvision


class VarlengthArrow(torch.utils.data.Dataset):
    def __init__(self, data_root, transform=None, use_mmap=False):
        self.data_root = data_root
        self.transform = transform
        self.use_mmap = use_mmap
        self.arrowfile = None
        self.reader = None
        # Only query the number of samples here; the file is re-opened
        # lazily in `__getitem__` (e.g. once per `DataLoader` worker).
        with pa.OSFile(self.data_root, 'rb') as f:
            with pa.ipc.open_file(f) as reader:
                self._len = reader.num_record_batches

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        if self.arrowfile is None:
            if self.use_mmap:
                self.arrowfile = pa.memory_map(self.data_root, 'r')
            else:
                self.arrowfile = pa.OSFile(self.data_root, 'rb')
            self.reader = pa.ipc.open_file(self.arrowfile)

        # Each sample was written as its own single-row record batch.
        row = self.reader.get_batch(idx)
        img_bytes = row['image_data'][0].as_py()
        target = row['label'][0].as_py()

        with io.BytesIO(img_bytes) as byte_stream:
            with Image.open(byte_stream) as img:
                img = img.convert("RGB")

        if self.transform:
            img = self.transform(img)
        return img, target


def _build_fake_varlength_dataset():
    # Create a dataset consisting of differently sized images.
    dset = torch.utils.data.ConcatDataset([
        torchvision.datasets.FakeData(334, image_size=(1, 16, 16)),
        torchvision.datasets.FakeData(333, image_size=(3, 244, 244)),
        torchvision.datasets.FakeData(333, image_size=(3, 256, 256)),
    ])
    return dset


def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'save_dir',
        help='Where to save the resulting Apache Arrow files.',
    )
    return parser.parse_args()


def main():
    args = _parse_args()
    splits = ["train", "val"]

    # Create the type singletons that our data will use.
    binary_t = pa.binary()
    uint16_t = pa.uint16()
    # How will our dataset be structured? In this case, each element in
    # the dataset will be a dictionary containing binary data under the
    # key "image_data" and an unsigned 16-bit integer under the key
    # "label".
    schema = pa.schema([
        pa.field('image_data', binary_t),
        pa.field('label', uint16_t),
    ])

    os.makedirs(args.save_dir, exist_ok=True)
    for split in splits:
        dset = _build_fake_varlength_dataset()
        sampler = torch.utils.data.RandomSampler(dset)
        with pa.OSFile(
                os.path.join(args.save_dir, f'varlength-data-{split}.arrow'),
                'wb',
        ) as f:
            with pa.ipc.new_file(f, schema) as writer:
                for i in sampler:
                    (pil_img, label) = dset[i]
                    # Encode the image as PNG bytes of arbitrary length.
                    with io.BytesIO() as bytes_buffer:
                        pil_img.save(bytes_buffer, 'png')
                        bytes_buffer.seek(0)
                        img_bytes = bytes_buffer.read()

                    # Write each sample as its own single-row record batch.
                    image_data = pa.array([img_bytes], type=binary_t)
                    label = pa.array([label], type=uint16_t)
                    batch = pa.record_batch(
                        [image_data, label], schema=schema)
                    writer.write(batch)


if __name__ == '__main__':
    main()