How to apply the same transform to a pair of pictures?

I am working on a stereo vision task, and I need to load a pair of pictures at a time. But `torchvision.transforms` behaves differently on the two pictures: for example, `RandomCrop` picks a different crop region for each one. Is there an easy way to apply the same transform to a pair of pictures?

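You could draw the random parameters once (e.g. with `transforms.RandomCrop.get_params`) and then apply those same parameters to both images through the functional API in `torchvision.transforms.functional`: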

import random

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset

class YourDataset(Dataset):
    """Stereo-pair dataset that applies identical random augmentation to both views.

    Transform objects such as ``transforms.RandomCrop`` draw fresh random
    parameters on every call, so applying one separately to each view would
    augment the two images differently. Here the random parameters are drawn
    once per pair and applied to both views through the functional API (``TF``).
    """

    def __init__(self, image_left_paths=(), image_right_paths=(), loader=None,
                 resize_size=(520, 520), crop_size=(512, 512), flip_prob=0.5):
        """
        Args:
            image_left_paths: Paths of the left views.
            image_right_paths: Paths of the right views; entry ``i`` is the
                partner of ``image_left_paths[i]``.
            loader: Callable mapping a path to an image (PIL image or image
                tensor). Defaults to ``torchvision.io.read_image`` in RGB mode.
            resize_size: ``(h, w)`` both views are resized to first.
            crop_size: ``(h, w)`` of the random crop; must fit inside
                ``resize_size``.
            flip_prob: Probability of applying each of the horizontal and
                vertical flips.

        Raises:
            ValueError: If the two path lists have different lengths.
        """
        self.image_left_paths = list(image_left_paths)
        self.image_right_paths = list(image_right_paths)
        if len(self.image_left_paths) != len(self.image_right_paths):
            raise ValueError(
                f"got {len(self.image_left_paths)} left images but "
                f"{len(self.image_right_paths)} right images"
            )
        self.loader = loader if loader is not None else self._default_loader
        self.crop_size = crop_size
        self.flip_prob = flip_prob
        # Resize has no random state, so one instance can be built once and reused.
        self.resize = transforms.Resize(size=resize_size)

    @staticmethod
    def _default_loader(path):
        """Read ``path`` as an RGB uint8 image tensor."""
        from torchvision.io import ImageReadMode, read_image

        return read_image(path, mode=ImageReadMode.RGB)

    @staticmethod
    def _to_float_tensor(image):
        """Convert a PIL image or image tensor to a float32 tensor in [0, 1]."""
        if isinstance(image, torch.Tensor):
            return TF.convert_image_dtype(image, torch.float32)
        return TF.to_tensor(image)

    def __getitem__(self, index):
        """Return the ``index``-th (left, right) pair as float tensors."""
        image_left = self.loader(self.image_left_paths[index])
        image_right = self.loader(self.image_right_paths[index])

        # Resize (deterministic)
        image_left = self.resize(image_left)
        image_right = self.resize(image_right)

        # Random crop: draw the window once, then apply it to both views.
        i, j, h, w = transforms.RandomCrop.get_params(
            image_left, output_size=self.crop_size)
        image_left = TF.crop(image_left, i, j, h, w)
        image_right = TF.crop(image_right, i, j, h, w)

        # Random horizontal flipping, with a single coin toss for the pair.
        # NOTE(review): for rectified stereo pairs, mirroring both views makes
        # the left view look like a right view (disparity changes sign); if the
        # model depends on left/right ordering, consider swapping the two views
        # after this flip — confirm for your task.
        if random.random() < self.flip_prob:
            image_left = TF.hflip(image_left)
            image_right = TF.hflip(image_right)

        # Random vertical flipping, with a single coin toss for the pair.
        if random.random() < self.flip_prob:
            image_left = TF.vflip(image_left)
            image_right = TF.vflip(image_right)

        return self._to_float_tensor(image_left), self._to_float_tensor(image_right)

    def __len__(self):
        """Number of stereo pairs."""
        return len(self.image_left_paths)
13 Likes

It is really helpful. Thank you very much.

Is `get_params` documented?

1 Like

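I think I have a simpler solution:
If the images are stacked along a new leading (batch) dimension, torchvision's tensor transforms draw their random parameters once and apply them to all of the images identically. This requires all images to have the same shape: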

import torch
import torchvision.transforms as T

def apply_same_transform(transform, *images):
    """Apply ``transform`` to several same-shaped image tensors with one shared random draw.

    The images are stacked along a new leading dimension, so the transform sees
    them as a single batch and transforms every image identically.

    Args:
        transform: Callable taking a ``(N, C, H, W)`` tensor, e.g. a torchvision
            tensor transform.
        *images: Image tensors of identical shape.

    Returns:
        Tuple of transformed tensors, in the same order as ``images``.

    Raises:
        ValueError: If no images are given.
    """
    if not images:
        raise ValueError("apply_same_transform() needs at least one image")
    # torch.stack adds the leading dimension (equivalent to cat of unsqueeze(0)).
    batch = torch.stack(images)
    return tuple(transform(batch).unbind(0))


if __name__ == "__main__":
    # Two fake images, identical on purpose: if different random parameters
    # were used, the transformed copies would no longer match.
    image = torch.randn((3, 128, 128))
    target = image.clone()

    image_trans, target_trans = apply_same_transform(
        T.RandomRotation(180), image, target)

    # print() so the result is visible when run as a script; a bare
    # expression statement is only echoed in an interactive session.
    print(torch.equal(image_trans, target_trans))  # expected: True

>> True