Applying different data augmentation per image in a mini-batch

As far as I know, the random transformations in the torchvision.transforms module (e.g. RandomCrop, RandomResizedCrop) sample their random parameters once per call, so when applied to a batched tensor they apply the same transformation to every image in the batch.
Is there any efficient way to apply different random transformations for each image in a given mini-batch?

Thanks in advance.

I think you would need to apply the random transformation to each sample separately, e.g. by wrapping a loop over the batch in transforms.Lambda:

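```python
import torch
from torchvision import transforms

# setup: a batch of 3 single-channel 10x10 images, shape [N, C, H, W]
x = torch.zeros(3, 1, 10, 10)
x[..., 3:7, 3:7] = 1.

# same transformation on each sample: one crop location is sampled
# per call and used for the whole batch
transform = transforms.RandomCrop(size=5)
y1 = transform(x)
print(y1)

# apply transformation per sample: a new crop location for every image
y2 = transforms.Lambda(lambda batch: torch.stack([transform(img) for img in batch]))(x)
print(y2)
```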

Thanks, your solution works well!
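If the Python loop over the batch becomes a bottleneck, RandomCrop specifically can be vectorized: sample one top-left offset per image and gather all crops at once with advanced indexing. This is a minimal sketch assuming a tensor batch of shape [N, C, H, W]; batch_random_crop is a hypothetical helper, not a torchvision function, and other transforms (e.g. RandomResizedCrop) would need their own batched implementation, so the per-sample loop remains the general solution.

```python
import torch


def batch_random_crop(batch: torch.Tensor, size: int) -> torch.Tensor:
    """Hypothetical helper: crop an independent random size x size window per image.

    Assumes `batch` has shape [N, C, H, W] with H, W >= size.
    """
    n, _, h, w = batch.shape
    device = batch.device
    # one independent top-left corner per sample
    tops = torch.randint(0, h - size + 1, (n,), device=device)
    lefts = torch.randint(0, w - size + 1, (n,), device=device)
    offsets = torch.arange(size, device=device)
    rows = tops[:, None] + offsets   # [N, size] row indices per sample
    cols = lefts[:, None] + offsets  # [N, size] column indices per sample
    sample_idx = torch.arange(n, device=device)[:, None, None]  # [N, 1, 1]
    # move channels last so the three index tensors address adjacent dims;
    # they broadcast to an [N, size, size] grid and C is carried along
    crops = batch.permute(0, 2, 3, 1)[sample_idx, rows[:, :, None], cols[:, None, :]]
    return crops.permute(0, 3, 1, 2)  # back to [N, C, size, size]


x = torch.zeros(3, 1, 10, 10)
x[..., 3:7, 3:7] = 1.
y3 = batch_random_crop(x, size=5)
print(y3.shape)  # torch.Size([3, 1, 5, 5])
```

Each output image equals x[i, :, top:top + size, left:left + size] for its own randomly drawn top and left, so the result matches the looped version in distribution without a Python-level loop.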