I am working on a stereo vision task, and I need to load a pair of pictures at a time. But the torchvision transforms behave differently on the two pictures. For example, RandomCrop picks a different crop region for each picture. Is there an easy way to apply the same transform to a pair of pictures?
You could use the functional API from torchvision.transforms:
import random

import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
from torchvision import transforms
class YourDataset(Dataset):
    """Template dataset returning identically augmented (left, right) image pairs.

    The random parameters (crop window, flip decisions) are sampled once per
    sample and then applied to both images through the functional API, so
    the two images of a pair always receive the same transformation.
    """

    def __init__(self):
        # Placeholders: replace with two index-aligned sequences of image paths.
        self.image_left_paths = ...
        self.image_right_paths = ...

    def __getitem__(self, index):
        """Return the augmented ``(image_left, image_right)`` pair for ``index``."""
        # Placeholders: replace with the actual image loading code.
        # NOTE(review): the steps below presumably expect PIL images - confirm.
        image_left = ...  # load image with index from self.image_left_paths
        image_right = ...  # load image with index from self.image_right_paths

        # Resize (deterministic, so the same transform object serves both images)
        resize = transforms.Resize(size=(520, 520))
        image_left = resize(image_left)
        image_right = resize(image_right)

        # Random crop: sample the window once, then crop both images with it
        i, j, h, w = transforms.RandomCrop.get_params(
            image_left, output_size=(512, 512))
        image_left = TF.crop(image_left, i, j, h, w)
        image_right = TF.crop(image_right, i, j, h, w)

        # Random horizontal flipping: one draw decides for both images
        if random.random() > 0.5:
            image_left = TF.hflip(image_left)
            image_right = TF.hflip(image_right)

        # Random vertical flipping: one draw decides for both images
        if random.random() > 0.5:
            image_left = TF.vflip(image_left)
            image_right = TF.vflip(image_right)

        image_left = TF.to_tensor(image_left)
        image_right = TF.to_tensor(image_right)
        return image_left, image_right

    def __len__(self):
        """Return the number of image pairs (length of the left path list)."""
        return len(self.image_left_paths)
14 Likes
It is really helpful. Thank you very much.
Is `get_params` documented?
1 Like
I think I have a simple solution:
If the images are concatenated, the transformations are applied to all of them identically:
import torch
import torchvision.transforms as T
# Create two fake images (identical for test purposes):
image = torch.randn((3, 128, 128))
target = image.clone()

# This is the trick: stack the images into one batched tensor of shape
# (2, 3, 128, 128), so a single transform call covers both of them.
both_images = torch.stack((image, target), dim=0)

# Apply the transformation to both images simultaneously; the random
# rotation angle is sampled once for the whole batch:
transformed_images = T.RandomRotation(180)(both_images)

# Get the transformed images:
image_trans = transformed_images[0]
target_trans = transformed_images[1]

# Compare the transformed images (prints True when both received the
# same rotation):
print(torch.all(image_trans == target_trans).item())