You could use the functional API from torchvision.transforms:
import random

import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets.folder import default_loader
class YourDataset(Dataset):
    """Dataset of left/right image pairs with synchronized augmentation.

    Both images of a pair go through the same resize, the same random crop
    window and the same random horizontal/vertical flips, so they stay
    spatially aligned. The random parameters are drawn once per sample and
    applied to both images through the functional transforms API.

    Attributes:
        RESIZE_SIZE: (height, width) both images are resized to first.
        CROP_SIZE: (height, width) of the shared random crop.
    """

    RESIZE_SIZE = (520, 520)
    CROP_SIZE = (512, 512)

    def __init__(self, image_left_paths=None, image_right_paths=None):
        """Store the paths of the left and right images.

        Args:
            image_left_paths: Iterable of paths to the left images.
                Defaults to an empty dataset.
            image_right_paths: Iterable of paths to the right images, in the
                same order as ``image_left_paths``.

        Raises:
            ValueError: If the two path collections differ in length.
        """
        super().__init__()
        self.image_left_paths = (
            [] if image_left_paths is None else list(image_left_paths)
        )
        self.image_right_paths = (
            [] if image_right_paths is None else list(image_right_paths)
        )
        if len(self.image_left_paths) != len(self.image_right_paths):
            raise ValueError(
                "image_left_paths and image_right_paths must have the same "
                f"length, got {len(self.image_left_paths)} and "
                f"{len(self.image_right_paths)}"
            )

    @staticmethod
    def _load_image(path):
        """Load one image from ``path`` for the transforms below.

        Uses torchvision's ``default_loader`` (with the default PIL backend
        this yields an RGB PIL image). Override this hook if the data needs
        a different loader or colour mode.
        """
        return default_loader(path)

    def __getitem__(self, index):
        """Return the augmented ``(image_left, image_right)`` tensor pair.

        Args:
            index: Position of the pair in the stored path lists.

        Returns:
            Tuple of two tensors produced by ``TF.to_tensor``.
        """
        image_left = self._load_image(self.image_left_paths[index])
        image_right = self._load_image(self.image_right_paths[index])

        # Resize
        resize = transforms.Resize(size=self.RESIZE_SIZE)
        image_left = resize(image_left)
        image_right = resize(image_right)

        # Random crop: sample the window once so both images share it.
        top, left, height, width = transforms.RandomCrop.get_params(
            image_left, output_size=self.CROP_SIZE
        )
        image_left = TF.crop(image_left, top, left, height, width)
        image_right = TF.crop(image_right, top, left, height, width)

        # Random horizontal flipping: one coin toss decides for both images.
        if random.random() > 0.5:
            image_left = TF.hflip(image_left)
            image_right = TF.hflip(image_right)

        # Random vertical flipping: again a single shared decision.
        if random.random() > 0.5:
            image_left = TF.vflip(image_left)
            image_right = TF.vflip(image_right)

        image_left = TF.to_tensor(image_left)
        image_right = TF.to_tensor(image_right)
        return image_left, image_right

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_left_paths)