varlength_arrow.py
    import argparse
    import io
    import os
    
    from PIL import Image
    import pyarrow as pa
    import torch
    import torchvision
    
    
    class VarlengthArrow(torch.utils.data.Dataset):
        def __init__(self, data_root, transform=None, use_mmap=False):
            self.data_root = data_root
            self.transform = transform
            self.use_mmap = use_mmap
    
            self.arrowfile = None
            self.reader = None
    
            # Open the file once just to count the record batches; the handles
            # used for actual reading are created lazily in __getitem__.
            if self.use_mmap:
                arrowfile = pa.memory_map(self.data_root, 'r')
            else:
                arrowfile = pa.OSFile(self.data_root, 'rb')
            try:
                with pa.ipc.open_file(arrowfile) as reader:
                    self._len = reader.num_record_batches
            finally:
                arrowfile.close()
    
        def __len__(self):
            return self._len
    
        def __getitem__(self, idx):
            # Allow negative indexing (i.e., from the end).
            if idx < 0:
                idx = self._len + idx
    
            if idx < 0 or idx >= self._len:
                raise IndexError(f'index {idx} out of bounds')
    
            # Open the file lazily on first access so that each DataLoader
            # worker process gets its own file handle and reader.
            if self.arrowfile is None:
                if self.use_mmap:
                    self.arrowfile = pa.memory_map(self.data_root, 'r')
                else:
                    self.arrowfile = pa.OSFile(self.data_root, 'rb')
                self.reader = pa.ipc.open_file(self.arrowfile)
    
            row = self.reader.get_batch(idx)
            img_bytes = row['image_data'][0].as_py()
            target = row['label'][0].as_py()
    
            with io.BytesIO(img_bytes) as byte_stream:
                with Image.open(byte_stream) as img:
                    img = img.convert("RGB")
    
            if self.transform:
                img = self.transform(img)
    
            return img, target
    
        def __del__(self):
            if self.reader is not None and hasattr(self.reader, 'close'):
                self.reader.close()
            if self.arrowfile is not None:
                self.arrowfile.close()
    
            # Chain up to a parent __del__ only if one is defined
            # (object itself does not define one).
            super_obj = super()
            if hasattr(super_obj, '__del__'):
                super_obj.__del__()
    
    
    def _build_fake_varlength_dataset():
        # Create a dataset consisting of differently sized images.
        dset = torch.utils.data.ConcatDataset([
            torchvision.datasets.FakeData(334, image_size=(1, 16, 16)),
            torchvision.datasets.FakeData(333, image_size=(3, 244, 244)),
            torchvision.datasets.FakeData(333, image_size=(3, 256, 256)),
        ])
        return dset
    
    
    def _parse_args():
        parser = argparse.ArgumentParser()
    
        parser.add_argument(
            'save_dir',
            help='Where to save the resulting Apache Arrow files.',
        )
    
        return parser.parse_args()
    
    
    def main():
        args = _parse_args()
    
        splits = ["train", "val"]
    
        # Create the type singletons that our data will use.
        binary_t = pa.binary()
        uint16_t = pa.uint16()
    
        # How will our dataset be structured? In this case, each element in
        # the dataset will be a dictionary containing binary data under the
        # key "image_data" and an unsigned 16-bit integer under the key
        # "label".
        schema = pa.schema([
            pa.field('image_data', binary_t),
            pa.field('label', uint16_t),
        ])
    
        os.makedirs(args.save_dir, exist_ok=True)
        for split in splits:
            dset = _build_fake_varlength_dataset()
            # Write in a randomised order so the records are pre-shuffled on disk.
            sampler = torch.utils.data.RandomSampler(dset)
    
            with pa.OSFile(
                    os.path.join(args.save_dir, f'varlength-data-{split}.arrow'),
                    'wb',
            ) as f:
                with pa.ipc.new_file(f, schema) as writer:
                    for i in sampler:
                        (pil_img, label) = dset[i]
                        with io.BytesIO() as bytes_buffer:
                            pil_img.save(bytes_buffer, 'png')
    
                            bytes_buffer.seek(0)
                            img_bytes = bytes_buffer.read()
    
                        image_data = pa.array([img_bytes], type=binary_t)
                        label = pa.array([label], type=uint16_t)
    
                        batch = pa.record_batch([image_data, label], schema=schema)
    
                        writer.write(batch)
    
    
    if __name__ == '__main__':
        main()
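
A minimal usage sketch, assuming the file above is importable as varlength_arrow and that main() was run with save_dir='data'; the resize target, batch size, and worker count below are arbitrary choices, not part of the script. Because the stored images have different sizes, a resizing transform is needed before the default collate function can stack them into a batch:

    import torch
    import torchvision.transforms as T

    from varlength_arrow import VarlengthArrow

    # Resize so the default collate_fn can stack images of different sizes.
    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
    ])

    dset = VarlengthArrow(
        'data/varlength-data-train.arrow',  # assumes save_dir='data' when writing
        transform=transform,
        use_mmap=True,  # optional; memory-map the file instead of buffered reads
    )

    # Each worker process opens its own Arrow reader lazily in __getitem__.
    loader = torch.utils.data.DataLoader(
        dset, batch_size=32, shuffle=True, num_workers=4)

    for imgs, labels in loader:
        # imgs: float tensor of shape (32, 3, 224, 224); labels: int64 tensor of shape (32,)
        pass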