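As an alternative to the functions from the tutorial, you could use torchvision's functional API (`torchvision.transforms.functional`). It lets you sample the random parameters once and apply them to several images.
Here is a small example that applies the same transformations to an image and its corresponding mask. It assumes `random`, `PIL.Image`, `torch.utils.data.Dataset`, `torchvision.transforms as transforms` and `torchvision.transforms.functional as TF` are imported: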
class MyDataset(Dataset):
    """Paired image/mask dataset that applies identical transforms to both.

    Item ``index`` is loaded from ``image_paths[index]`` and
    ``target_paths[index]`` and passed through :meth:`transform`. That method
    resizes and crops both images, and in training mode randomly flips them.
    The *same* sampled parameters are used for image and mask, so the two
    stay spatially aligned.

    Args:
        image_paths: Sequence of paths to the input images.
        target_paths: Sequence of paths to the mask images, aligned
            index-by-index with ``image_paths``.
        train: If True (default), use a random crop plus random horizontal
            and vertical flips. If False, use a deterministic center crop
            and no flips.
        resize_size: ``(height, width)`` both images are resized to first.
        crop_size: ``(height, width)`` of the final crop.
    """

    def __init__(self, image_paths, target_paths, train=True,
                 resize_size=(520, 520), crop_size=(512, 512)):
        self.image_paths = image_paths
        self.target_paths = target_paths
        # Selects random (train) vs. deterministic (eval) cropping/flipping
        # in transform().
        self.train = train
        self.resize_size = resize_size
        self.crop_size = crop_size

    def transform(self, image, mask):
        """Apply the same spatial transforms to ``image`` and ``mask``.

        Args:
            image: Input PIL image.
            mask: Mask PIL image corresponding to ``image``.

        Returns:
            ``(image, mask)`` converted with ``TF.to_tensor``.
        """
        # Resize. The mask uses nearest-neighbour interpolation so label
        # values are not blended into new, invalid values at region borders.
        image = TF.resize(image, self.resize_size)
        mask = TF.resize(mask, self.resize_size,
                         interpolation=transforms.InterpolationMode.NEAREST)

        if self.train:
            # Sample one crop window and apply it to both images.
            i, j, h, w = transforms.RandomCrop.get_params(
                image, output_size=self.crop_size)
            image = TF.crop(image, i, j, h, w)
            mask = TF.crop(mask, i, j, h, w)

            # Random horizontal flipping (same decision for both).
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)

            # Random vertical flipping (same decision for both).
            if random.random() > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)
        else:
            # Deterministic evaluation path: same output size, no randomness.
            image = TF.center_crop(image, self.crop_size)
            mask = TF.center_crop(mask, self.crop_size)

        # Transform to tensor.
        # NOTE(review): TF.to_tensor scales 8-bit data to [0, 1] floats. If the
        # mask stores integer class indices, consider converting it with
        # torch.as_tensor(np.array(mask), dtype=torch.long) instead.
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        return image, mask

    def __getitem__(self, index):
        """Load, transform and return the ``(image, mask)`` pair at ``index``."""
        # Context managers close the underlying file handles once the
        # transformed tensors have been produced.
        with Image.open(self.image_paths[index]) as image, \
                Image.open(self.target_paths[index]) as mask:
            x, y = self.transform(image, mask)
        return x, y

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)