I am working on a stereo vision task, and I need to load a pair of pictures at a time. But torchvision.transforms behaves differently on the two pictures. For example, RandomCrop picks a different crop region for each one. Is there an easy way to apply the same transform to a pair of pictures?
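You could use the functional API from torchvision.transforms.functional: sample the random parameters once, then apply them to both images.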
import random

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF


class YourDataset(Dataset):
    def __init__(self, image_left_paths, image_right_paths):
        # Lists of file paths, one entry per stereo pair
        self.image_left_paths = image_left_paths
        self.image_right_paths = image_right_paths

    def __getitem__(self, index):
        # Load the pair at this index (PIL is one option)
        image_left = Image.open(self.image_left_paths[index]).convert("RGB")
        image_right = Image.open(self.image_right_paths[index]).convert("RGB")

        # Resize (deterministic, so the class-based transform is fine)
        resize = transforms.Resize(size=(520, 520))
        image_left = resize(image_left)
        image_right = resize(image_right)

        # Random crop: sample the window once, apply it to both images
        i, j, h, w = transforms.RandomCrop.get_params(
            image_left, output_size=(512, 512))
        image_left = TF.crop(image_left, i, j, h, w)
        image_right = TF.crop(image_right, i, j, h, w)

        # Random horizontal flipping (one coin toss for both images)
        if random.random() > 0.5:
            image_left = TF.hflip(image_left)
            image_right = TF.hflip(image_right)

        # Random vertical flipping (one coin toss for both images)
        if random.random() > 0.5:
            image_left = TF.vflip(image_left)
            image_right = TF.vflip(image_right)

        image_left = TF.to_tensor(image_left)
        image_right = TF.to_tensor(image_right)
        return image_left, image_right

    def __len__(self):
        return len(self.image_left_paths)
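The key idea in the answer above is to draw each random parameter once and reuse it for both views. Here is a minimal, framework-free sketch of that idea using NumPy arrays; paired_random_crop and paired_random_flip are hypothetical helper names for illustration, not torchvision APIs, and the arrays are assumed to be HxW or HxWxC with identical spatial size.

```python
import random

import numpy as np


def paired_random_crop(left, right, out_h, out_w):
    """Crop the same randomly chosen window from both images (hypothetical helper)."""
    assert left.shape[:2] == right.shape[:2], "the pair must share H and W"
    h, w = left.shape[:2]
    # Sample the crop origin ONCE, then reuse it for both images
    i = random.randint(0, h - out_h)
    j = random.randint(0, w - out_w)
    return left[i:i + out_h, j:j + out_w], right[i:i + out_h, j:j + out_w]


def paired_random_flip(left, right, p=0.5):
    """Flip both images horizontally, or neither (hypothetical helper)."""
    # A single coin toss decides for both images
    if random.random() < p:
        return left[:, ::-1], right[:, ::-1]
    return left, right


left = np.arange(10 * 12).reshape(10, 12)
right = left + 1000  # a "different" image with the same layout

l_crop, r_crop = paired_random_crop(left, right, 4, 5)
l_out, r_out = paired_random_flip(l_crop, r_crop)

print(l_out.shape, r_out.shape)  # (4, 5) (4, 5)
# The per-pixel difference stays 1000 only if both got the same crop and flip
print(np.array_equal(r_out - l_out, np.full((4, 5), 1000)))  # True
```

For stereo matching specifically, note that horizontally flipping both views without swapping them reverses the direction of disparity (the flipped left image now behaves like a right view), so some pipelines swap left and right after an hflip, or skip horizontal flips altogether.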
It is really helpful. Thank you very much.
Is get_params documented?
I think I have a simple solution:
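If the images are stacked into a single tensor, a torchvision transform samples its random parameters once per call, so every image in the stack gets the identical transformation (this requires tensor inputs, not PIL images):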
import torch
import torchvision.transforms as T

# Create two fake images (identical for test purposes)
image = torch.randn((3, 128, 128))
target = image.clone()

# This is the trick: stack the images into one (2, 3, H, W) tensor
both_images = torch.cat((image.unsqueeze(0), target.unsqueeze(0)), 0)

# Apply the transformation to both images simultaneously
transformed_images = T.RandomRotation(180)(both_images)

# Split the result back into the individual images
image_trans = transformed_images[0]
target_trans = transformed_images[1]

# Compare the transformed images
print(torch.all(image_trans == target_trans).item())  # True