Optimizing the data augmentation process

After seeing some libraries proposed for optimizing the data loading / pre-processing phases of training (e.g., FFCV), I have been trying to see whether the same is possible in native PyTorch, particularly for data augmentation, since that seems to be the largest bottleneck.

My current setup performs some transforms in the __getitem__ method of my dataset object, such as resizing and random / center cropping (i.e., getting all images to the same size), and then applies the remaining transforms (e.g., RandAugment, horizontal flipping) to the whole batch of images on the GPU. A small code snippet of this:

import torch
import torch.nn as nn

import torchvision
from torchvision.transforms import v2

# per-sample transforms run on the CPU inside __getitem__
# (note: nn.Sequential takes modules as positional arguments, not a list)
initial_transform = nn.Sequential(
    v2.PILToTensor(),
    v2.Resize(224),
    v2.RandomCrop(224)
)

# batch-level transforms run on the GPU after the DataLoader
transform = nn.Sequential(
    v2.RandAugment(),
    v2.RandomHorizontalFlip(),
    v2.ToDtype(torch.float32, scale=True)
)
transform = transform.cuda()

dataset = MyDataset(..., transform=initial_transform)
dataloader = torch.utils.data.DataLoader(dataset, ...)

for images, labels in dataloader:
    images = images.cuda()
    images = transform(images) # <<<<<<<<<<<<<<<<
    
# Anything else after this

My issue is that the images = transform(images) line above applies the exact same transformation parameters to every image in the batch (i.e., all images get flipped horizontally or none do), whereas I would like each image to be treated independently so the result matches doing all transforms in __getitem__.
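
As a quick illustration of what I mean (a toy check I put together, not my real pipeline), a 4D tensor is treated as a single batch, so RandomHorizontalFlip makes one coin flip for the whole thing:

import torch
from torchvision.transforms import v2

# dummy batch of 4 tiny "images" with distinct values so a flip is detectable
flip = v2.RandomHorizontalFlip(p=0.5)
batch = torch.arange(4 * 3 * 2 * 2, dtype=torch.float32).reshape(4, 3, 2, 2)

out = flip(batch)
flipped = [bool((out[i] != batch[i]).any()) for i in range(len(batch))]
print(flipped)  # always all True or all False, never a mix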

I know I could replace images = transform(images) with images = torch.stack([transform(img) for img in images]), but I feel that would put me back at square one in terms of data augmentation taking a substantial amount of time.

Does anyone have any suggestions for this? I like that the v2 transforms can now take batches of images and can be put on the GPU, but it’d be even nicer if each transform were applied independently per image.

I will note that I haven’t done any timing comparisons yet, so there is some chance that the majority of the data-augmentation cost is in the PIL-to-tensor conversion or the resize operation and my current efforts will make a negligible difference, but I thought I’d ask so I can test anyway.
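
When I do get to timing it, something like this rough sketch is what I have in mind (placeholder sizes, reusing the GPU transform object from the snippet above), comparing the batched call against the per-image loop:

import time
import torch

def time_fn(fn, images, n_iters=50):
    fn(images)                      # warm-up so lazy CUDA init doesn't skew the numbers
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        fn(images)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

images = torch.randint(0, 256, (128, 3, 224, 224), dtype=torch.uint8, device="cuda")

batched = lambda x: transform(x)                                  # same params for the whole batch
per_image = lambda x: torch.stack([transform(img) for img in x])  # independent params, but a Python loop

print(f"batched:   {time_fn(batched, images) * 1e3:.2f} ms/iter")
print(f"per-image: {time_fn(per_image, images) * 1e3:.2f} ms/iter")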

Hi, take a look at the Kornia package. It can apply the same transform with different parameters for each image in a batch, it is built on PyTorch, and it runs on both CPU and GPU. I actually like it more than torchvision.
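
Something like this is what I mean (a rough, untested sketch; the specific ops are just examples). Kornia augmentations sample parameters per image when same_on_batch=False, which is the default:

import torch
import kornia.augmentation as K

# each image in the batch gets its own random flip / jitter parameters
batch_augment = K.AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5, same_on_batch=False),
    K.ColorJitter(0.2, 0.2, 0.2, 0.1, same_on_batch=False),
).cuda()

images = torch.rand(32, 3, 224, 224, device="cuda")  # placeholder batch in [0, 1]
augmented = batch_augment(images)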

Unless the augmentation operations you use have efficient CUDA implementations, I doubt you’ll see any measurable performance improvement. A custom kernel, running on either the CPU or GPU, that aggregates all the augmentations is much faster, but also much less flexible than the standard approach.

From this performance evaluation on the torchvision GitHub, it seems like a good number of the transforms should be much faster on the GPU (e.g., Resize, RandAugment, etc.). If I had to guess, they ran that evaluation by passing a batch of images to an nn.Sequential object containing the transformations, which applies the same transformation parameters to every example in the batch (which is not the desired outcome here).

For the most part I’ve accepted my loss here, in that I won’t be able to get both the performance gains and the per-example augmentations that I want. However, I will try torch.jit to see whether that results in any gains (since JIT compilation with Numba seems to be the big improvement in FFCV).
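
For reference, this is the kind of thing I mean by trying torch.jit (a sketch only, untested; I’m using the v1 transforms here since those are the ones documented as scriptable inside nn.Sequential, and scripting by itself still doesn’t give per-image randomness):

import torch
import torch.nn as nn
from torchvision import transforms  # v1 transforms are scriptable on tensor inputs

# script a tensor-only pipeline; this only tests whether scripting speeds up
# the GPU pass, it does not change how the random parameters are sampled
scripted_transform = torch.jit.script(nn.Sequential(
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ConvertImageDtype(torch.float32),
))

images = torch.randint(0, 256, (32, 3, 224, 224), dtype=torch.uint8, device="cuda")
out = scripted_transform(images)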