After seeing some libraries proposed to optimize the data loading / pre-processing phases of training (e.g., FFCV), I have been trying to see whether the same is possible in native PyTorch, particularly for data augmentation, since that seems to be the largest bottleneck.
My current approach is to perform some transforms in the __getitem__ method of my dataset object, such as resizing and random / center cropping (i.e., getting all images to the same size), and then perform the remaining transforms (e.g., RandAugment, horizontal flipping, etc.) on the whole batch of images on the GPU. Small code snippet of this:
import torch
import torch.nn as nn
from torchvision.transforms import v2

# Per-sample transforms, run on the CPU inside __getitem__ so that every
# image comes out the same size and can be collated into a batch.
initial_transform = nn.Sequential(
    v2.PILToTensor(),
    v2.Resize(224),
    v2.RandomCrop(224),
)

# Batch-level transforms, run on the GPU after collation.
transform = nn.Sequential(
    v2.RandAugment(),
    v2.RandomHorizontalFlip(),
    v2.ToDtype(torch.float32, scale=True),
)
transform = transform.cuda()

dataset = MyDataset(..., transform=initial_transform)
dataloader = torch.utils.data.DataLoader(dataset, ...)

for images, labels in dataloader:
    images = images.cuda()
    images = transform(images)  # <<<<<<<<<<<<<<<<
    # Anything else after this
My issue is that the images = transform(images) line above applies the exact same transformation parameters to every image in the batch (i.e., all images get flipped horizontally or none do), whereas I would like each image to be treated independently, so that it is exactly equivalent to doing all transforms in __getitem__.
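A minimal check that demonstrates the behavior (the batch shape and flip probability here are arbitrary placeholders):

import torch
from torchvision.transforms import v2

flip = v2.RandomHorizontalFlip(p=0.5)
batch = torch.rand(8, 3, 32, 32)  # arbitrary batch of 8 random images
out = flip(batch)

# One random draw is made for the whole 4D tensor, so all images agree:
# this prints eight Trues (nothing flipped) or eight Falses (all flipped),
# never a mix.
for i in range(batch.shape[0]):
    print(i, torch.equal(out[i], batch[i]))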
I know I could insert the line images = torch.stack([transform(img) for img in images]) instead of images = transform(images), but I suspect that would put me back at square one in terms of data augmentation taking a substantial amount of time.
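For a single transform like the flip, the per-image randomness can be hand-rolled while keeping the work batched on the GPU (a sketch of my own, not a torchvision API, and it clearly doesn't generalize to something like RandAugment):

import torch

def per_sample_hflip(images: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # Draw one Bernoulli sample per image, on the same device as the batch.
    mask = torch.rand(images.shape[0], device=images.device) < p
    out = images.clone()
    out[mask] = out[mask].flip(-1)  # flip the selected images along the width axis
    return out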
Does anyone have any suggestions for this? I like that the v2 transforms can now take batches of images and can be put on the GPU, but it'd be even nicer if each transform were applied independently per image.
I will note that I haven't done any timing comparisons yet, so there is some chance that the majority of the augmentation cost is in the PIL-to-tensor conversion or the resize operation, and my current efforts will make a negligible difference. I just thought I'd ask so I can test anyway.
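For reference, this is roughly how I plan to time the two variants (a rough sketch; the batch shape is a placeholder, and torch.cuda.synchronize is needed so the queued GPU kernels are actually counted):

import time
import torch
import torch.nn as nn
from torchvision.transforms import v2

transform = nn.Sequential(
    v2.RandAugment(),
    v2.RandomHorizontalFlip(),
    v2.ToDtype(torch.float32, scale=True),
).cuda()

images = torch.randint(0, 256, (128, 3, 224, 224), dtype=torch.uint8).cuda()

def timed(fn, n=20):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n):
        fn()
    torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
    return (time.perf_counter() - start) / n

batched = timed(lambda: transform(images))
per_image = timed(lambda: torch.stack([transform(img) for img in images]))
print(f"batched: {batched * 1e3:.1f} ms/iter, per-image: {per_image * 1e3:.1f} ms/iter")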