Hello,
I’m working on a dataset class in PyTorch where I need to load several 4K images in the `__getitem__` method. Each call loads 7 images, including a ground-truth image (GT). Currently, I’m using the following approach:
```python
from PIL import Image
import torch

def __getitem__(self, index):
    sample = self.files[index]
    # Build a fresh path list each call; appending to the stored dict
    # would keep growing it across epochs
    paths = sample["files"] + [sample["GT"]]
    # Synchronously decode the images one by one
    imgs = [Image.open(p) for p in paths]
    # Apply the random transforms, then convert and stack into one tensor
    images = torch.stack(self.to_tensor(self.dataset_config.tfms(imgs)))
    return images
```
However, loading these images at 4K resolution is quite slow and significantly bottlenecks my training pipeline. Additionally, the `to_tensor` pipeline I’m using for the conversion also seems slow, which further hurts performance:
```python
# TF is torchvision.transforms.v2
self.to_tensor: Callable[..., torch.Tensor] = TF.Compose(
    [
        TF.ToImage(),
        TF.ToDtype(torch.float32, scale=True),
    ]
)
```
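For context, this is roughly how I’ve been measuring the conversion on a single image (a quick sketch; the file path is just a placeholder):

```python
import time

import torch
import torchvision.transforms.v2 as TF
from PIL import Image

to_tensor = TF.Compose([TF.ToImage(), TF.ToDtype(torch.float32, scale=True)])

img = Image.open("frame_4k.png")  # placeholder path to a 4K test image
img.load()  # force the decode up front so only the conversion is timed

start = time.perf_counter()
t = to_tensor(img)
print(f"to_tensor on one 4K frame: {time.perf_counter() - start:.3f} s")
```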
The random transformations I’m applying via `self.dataset_config.tfms` (sketched below):
- `RandomShortestSize`
- `RandomCrop`
- `RandomHorizontalFlip`
- `RandomVerticalFlip`
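For reference, that pipeline is composed roughly like this (the sizes here are placeholders, not my real config):

```python
import torchvision.transforms.v2 as TF

# Placeholder sizes, not the actual values from my config
tfms = TF.Compose(
    [
        TF.RandomShortestSize(min_size=1080, max_size=2160),
        TF.RandomCrop(512),
        TF.RandomHorizontalFlip(p=0.5),
        TF.RandomVerticalFlip(p=0.5),
    ]
)
```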
I’ve considered using `asyncio` to fetch the images in parallel, but if anything it seemed to be slower (?).
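Roughly what I tried (a minimal sketch; since PIL decoding is blocking, I pushed it onto the default executor):

```python
import asyncio

from PIL import Image

def _decode(path):
    img = Image.open(path)
    img.load()  # Image.open is lazy; force the decode in the worker thread
    return img

async def _load_all(paths):
    loop = asyncio.get_running_loop()
    # Run each blocking decode in the default thread-pool executor
    return await asyncio.gather(
        *(loop.run_in_executor(None, _decode, p) for p in paths)
    )

def load_images(paths):
    # asyncio.run creates a fresh event loop on every __getitem__ call,
    # which I suspect is where the extra overhead comes from
    return asyncio.run(_load_all(paths))
```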
My questions:
- How can I efficiently parallelize the loading of multiple images (7 per sample) in PyTorch’s `__getitem__`? I’m open to using threading, multiprocessing, or any other solution that can speed this up (e.g. something like the thread-pool sketch below).
- The `to_tensor` conversion also seems slow; is there a faster way to do it?
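For concreteness, this is the kind of threading approach I have in mind (a sketch only; `load_images_threaded` and `_decode` are hypothetical helpers, and I haven’t benchmarked this):

```python
from concurrent.futures import ThreadPoolExecutor

from PIL import Image

def _decode(path):
    img = Image.open(path)
    img.load()  # force the full decode inside the worker thread
    return img

def load_images_threaded(paths, max_workers=7):
    # Pillow releases the GIL during most of the decode, so a small
    # per-sample thread pool may overlap the 7 reads
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(_decode, paths))
```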
Any suggestions would be greatly appreciated!
Thanks!