Torchvision: same random transforms on multiple images


I am working on an optical flow algorithm where the input is 2 images of size HxWx3 and the target is a tensor of size HxWx2. To do data augmentation, I need to apply the same random transformation to all three tensors. Take this augmentation for example:

        import torchvision.transforms as transforms

        aug_transforms = transforms.Compose([
            transforms.RandomResizedCrop((614, 216), scale=(0.1, 1.0)),
        ])

The torchvision transforms accept PIL images as input. To solve this problem, I have two questions:

  • Can we save the random state of a given transform so that we can apply the same randomness to img1, img2, and the flow?
  • Can we combine all the data into a PIL object, apply the transformation, and get back tensors with the same value range we had at the beginning?

I would recommend using the functional API (torchvision.transforms.functional), as shown here, since working with the seed might be hard and yield unwanted side effects.

That being said, I’m not sure at the moment whether PIL will accept target images with two channels, so you might need to split the channels (or add a dummy channel to the target).

Thank you for the reply. I see that this issue comes up a lot in projects whose output is a tensor. Just out of curiosity, do you know why torchvision uses PIL instead of plain NumPy arrays? With arrays, we could apply the transforms to a list of tensors with different numbers of channels.

For people having the same issue: I found an interesting library (albumentations) that performs data augmentation on NumPy arrays with any number of channels and supports multiple targets, bounding boxes, masks, …

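import albumentations as albu
import cv2

# One Compose for all inputs: 'image0' and 'image1' are treated as extra images,
# so they receive exactly the same random crop as 'image'.
co_aug_transforms = albu.Compose(
    [
        # positional args follow the older albumentations signature: min_max_height, height, width
        albu.RandomSizedCrop((300, 436), 400, 940, w2h_ratio=1024/436, p=0.5),
    ],
    additional_targets={'image0': 'image', 'image1': 'image'},
)

# readflo and the path_to_* variables are user-defined (readflo returns an HxWx2 float32 array)
flow = readflo(path_to_flow)
frame1 = cv2.cvtColor(cv2.imread(path_to_frame1), cv2.COLOR_BGR2RGB)
frame2 = cv2.cvtColor(cv2.imread(path_to_frame2), cv2.COLOR_BGR2RGB)

transformed = co_aug_transforms(image=frame1, image0=frame2, image1=flow)

frame1 = transformed['image']
frame2 = transformed['image0']
# The crop resizes the flow field but does not rescale the (u, v) vectors themselves.
flow = transformed['image1']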
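The underlying idea, sampling the random parameters once and applying them to every array regardless of its channel count, can also be sketched in plain NumPy. This is not albumentations' API; the helper co_random_crop_flip is hypothetical and only covers a random crop plus a random horizontal flip. A pure crop leaves the flow vectors unchanged, but mirroring the frames reverses horizontal motion, so the u component of the flow is negated on a flip.

```python
import numpy as np


def co_random_crop_flip(frame1, frame2, flow, crop_hw, rng=None):
    """Hypothetical helper: one random crop and flip shared by both frames and the flow.

    frame1, frame2: (H, W, 3) arrays; flow: (H, W, 2) array with (u, v) in the last axis.
    """
    rng = np.random.default_rng() if rng is None else rng
    H, W = frame1.shape[:2]
    ch, cw = crop_hw

    # Draw the random parameters once ...
    top = int(rng.integers(0, H - ch + 1))
    left = int(rng.integers(0, W - cw + 1))
    flip = bool(rng.random() < 0.5)

    # ... and reuse them for every array, whatever its number of channels.
    out = [a[top:top + ch, left:left + cw] for a in (frame1, frame2, flow)]
    if flip:
        out = [a[:, ::-1] for a in out]
        # Mirroring the frames reverses horizontal motion: negate u, keep v.
        out[2] = out[2] * np.array([-1.0, 1.0], dtype=out[2].dtype)
    return tuple(np.ascontiguousarray(a) for a in out)
```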