Cropping batches at the same position

You could use a counter to choose whether to resample the random crop parameters or reuse them.
Here is a small (untested) example:

class MyDataset(Dataset):
    """Image dataset that reuses one random crop for consecutive samples.

    Crop parameters are resampled for every even ``index`` and reused for
    the following odd ``index``, so samples ``(0, 1)``, ``(2, 3)``, ... share
    the same crop window. This pairing relies on the indices being requested
    in sequential order (e.g. no shuffling), because the cached parameters
    live on the dataset instance.

    Note that the horizontal/vertical flips are still drawn independently
    for every sample; only the crop position is shared.
    """

    def __init__(self, image_paths):
        """Store the list of image file paths to load from.

        Args:
            image_paths: Sequence of paths accepted by ``Image.open``.
        """
        self.image_paths = image_paths
        # Most recently sampled crop as (top, left, height, width).
        # Empty until the first crop has been drawn.
        self.crop_indices = []

    def transform(self, image, resample):
        """Resize, crop, randomly flip and convert ``image`` to a tensor.

        Args:
            image: Image to transform (as returned by ``Image.open``).
            resample: If true, draw new random crop parameters; otherwise
                reuse the ones cached from the previous call.

        Returns:
            The transformed image as produced by ``TF.to_tensor``.
        """
        # Resize
        resize = transforms.Resize(size=(520, 520))
        image = resize(image)

        # Random crop. Also resample when nothing is cached yet, so that a
        # first access with an odd index does not fail while unpacking an
        # empty ``crop_indices``.
        if resample or not self.crop_indices:
            self.crop_indices = transforms.RandomCrop.get_params(
                image, output_size=(512, 512))
        i, j, h, w = self.crop_indices
        image = TF.crop(image, i, j, h, w)

        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)

        # Random vertical flipping
        if random.random() > 0.5:
            image = TF.vflip(image)

        # Transform to tensor
        image = TF.to_tensor(image)
        return image

    def __getitem__(self, index):
        """Load the image at ``index`` and return its transformed tensor."""
        # Even indices start a new pair and therefore draw a fresh crop.
        resample = index % 2 == 0
        # The context manager closes the underlying file once the pixel data
        # has been consumed by the transform.
        with Image.open(self.image_paths[index]) as image:
            x = self.transform(image, resample)
        return x

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

Let me know if that works for you.

2 Likes