I know this thread is a bit old now, but the core issue the OP hit back then is still the same today: the CPU decode/resize step is usually the real bottleneck. torchvision.transforms.v2 helped a lot by moving many ops to the GPU (color jitter, normalize, etc.), but there's still overhead because each op ends up as a separate GPU kernel launch, and each launch means another full read/write pass over the batch.
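For intuition on why fusion helps: several of these ops are just per-pixel affine math, so chaining them reads and rewrites the whole tensor once per op, while a fused version folds everything into one scale-and-bias pass. A toy numpy sketch of that collapse (my own illustration, not any library's actual code), using brightness + normalize:

```python
import numpy as np

def separate_passes(img, brightness, mean, std):
    """Two elementwise passes over the image, like two kernel launches."""
    img = img * brightness        # pass 1: brightness
    return (img - mean) / std     # pass 2: normalize

def fused_pass(img, brightness, mean, std):
    """Precompute a single scale/bias, then do one elementwise pass."""
    scale = brightness / std
    bias = -mean / std
    return img * scale + bias     # the only pass

rng = np.random.default_rng(0)
img = rng.random((3, 8, 8)).astype(np.float32)
mean = np.full((3, 1, 1), 0.5, dtype=np.float32)
std = np.full((3, 1, 1), 0.25, dtype=np.float32)

a = separate_passes(img, 1.2, mean, std)
b = fused_pass(img, 1.2, mean, std)
assert np.allclose(a, b, atol=1e-5)  # same result, half the memory traffic
```

On a GPU the win is the same idea at scale: fewer launches and, more importantly, fewer trips through global memory.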
If anyone here has already optimized the CPU side (TurboJPEG, WebDataset, DALI, or just pre-resizing offline as the OP did), there's still some speed left to squeeze out after the resize step.
I’ve been working on a small Triton-based library that fuses the common torchvision.transforms.v2 image augmentation ops — crop, flip, brightness/contrast/saturation, grayscale, normalize — into a single GPU kernel.
It doesn’t fix decode/resize, but once your batch is already on the GPU, the fused kernel is usually 5–12× faster than torchvision v2’s separate kernels, especially for larger images.
Super easy drop-in:
```python
augment = ta.TritonFusedAugment(
    crop_size=224,
    horizontal_flip_p=0.5,
    brightness=0.2, contrast=0.2, saturation=0.2,
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2470, 0.2435, 0.2616),
    same_on_batch=False,  # each image gets different random params (default)
)

for images, labels in train_loader:
    images, labels = images.cuda(), labels.cuda()
    images = augment(images)  # all ops in 1 kernel per batch! 🚀
    ...
```
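If you're curious what "fused" means for the geometric ops too: conceptually the kernel computes, for each *output* pixel, which source pixel it needs (crop offset, optionally mirrored for the flip), reads it once, and applies all the color/normalize math before writing. A hedged numpy sketch of that idea (my illustration of the general technique, not this library's implementation):

```python
import numpy as np

def naive(img, top, left, size, flip, mean, std):
    """Sequential ops: each step materializes an intermediate tensor."""
    out = img[:, top:top + size, left:left + size]  # crop
    if flip:
        out = out[:, :, ::-1]                       # horizontal flip
    return (out - mean) / std                       # normalize

def fused(img, top, left, size, flip, mean, std):
    """One gather + one affine pass: each source pixel is read exactly once."""
    ys = top + np.arange(size)
    xs = left + np.arange(size)
    if flip:
        xs = xs[::-1]                               # fold the flip into index math
    gathered = img[:, ys[:, None], xs[None, :]]     # single gather (crop + flip)
    return (gathered - mean) / std                  # single elementwise pass

rng = np.random.default_rng(0)
img = rng.random((3, 32, 32)).astype(np.float32)
mean = np.full((3, 1, 1), 0.5, dtype=np.float32)
std = np.full((3, 1, 1), 0.25, dtype=np.float32)

a = naive(img, 4, 6, 24, True, mean, std)
b = fused(img, 4, 6, 24, True, mean, std)
assert np.allclose(a, b)
```

In a real Triton kernel the index arithmetic happens per program instance, so no intermediate tensors ever hit global memory.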
Might help anyone who already solved the CPU bottleneck but still sees GPU-side augmentation show up in the profiler.