Some doubts on the use of Transforms with a custom Dataset

I’ve been experimenting with Transforms, and it seems that when we pass a picture to a transform like RandomPerspective or CenterCrop, it outputs only one picture: the transformed one.

So, we’re losing the original picture.

My idea of data augmentation was to train on augmented_data = original_data + transformed_data… but it seems that if we use transforms, we’ll be training only on the transformed_data.

  1. How would one work in PyTorch with augmented_data as defined above when implementing a custom Dataset?

Some transforms, like FiveCrop, output more than one picture. This PyTorch documentation page seems to recommend using torch.stack to stack the ‘tensorized’ pictures. However, when I do that, I’m creating an extra dimension in my final tensor. So, instead of a 3D (unbatched) or 4D (batched) tensor, I’ll have a 4D (unbatched) / 5D (batched) tensor, where the initial dimension holds the number of pictures output by the transform.
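
For example, a minimal illustration of the extra dimension I mean (using a blank placeholder image just to show the shapes):

import torch
from PIL import Image
from torchvision import transforms

img = Image.new('RGB', (32, 32))  # blank placeholder image
crops = transforms.FiveCrop(24)(img)  # tuple of 5 PIL images
stacked = torch.stack([transforms.PILToTensor()(c) for c in crops])
print(stacked.shape)  # torch.Size([5, 3, 24, 24]) -- extra leading dimension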

  2. How does one work with Transforms like FiveCrop, which output more than one picture, when creating a custom Dataset?

  1. I’m not sure what you mean by “losing” the original. Transforms should not mutate the data in the dataset itself. If you mean that only the transformed data will make it into the batch, then this is expected, especially in the context of multi-epoch training, where the same original images are transformed in different ways according to randomly generated augmentation parameters such as the crop size or other distortions. If you wish to sometimes use the original data, the simplest thing to do would be to create a copy of your dataset that does not apply any transformations and to selectively load data from that dataset alongside the dataset that does apply transformations (a rough sketch follows after this list). You could also probabilistically disable some transformations, e.g. via RandomApply (see the Torchvision documentation on pytorch.org).

  2. It depends on the task you are training for; for a simple classification task, you could just repeat the label, assuming the label should be the same for each of the five crops.
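
For point 1, here’s a rough sketch of the “copy of the dataset” idea, combining the untouched and augmented views with torch.utils.data.ConcatDataset (CIFAR10 and the specific transforms are just placeholders):

import torchvision
from torch.utils.data import ConcatDataset, DataLoader

# RandomApply runs the wrapped transforms with probability p, so some
# samples from this dataset also pass through un-augmented.
augment = torchvision.transforms.Compose([
    torchvision.transforms.RandomApply(
        [torchvision.transforms.RandomPerspective()], p=0.5),
    torchvision.transforms.ToTensor()])
plain = torchvision.transforms.ToTensor()

# Two views of the same underlying data: one augmented, one untouched.
augmented = torchvision.datasets.CIFAR10(root='.', transform=augment, download=True)
original = torchvision.datasets.CIFAR10(root='.', transform=plain)

# original_data + transformed_data in a single dataset.
combined = ConcatDataset([original, augmented])
loader = DataLoader(combined, batch_size=4, shuffle=True)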

@eqy Thanks for the reply!

  1. It now makes a bit more sense. :wink:
  2. Could you give me a short example/reference of how to incorporate your suggestion into a custom Dataset loaded by a DataLoader? I’ve tried something similar, but it didn’t work.

I think my earlier reply was inaccurate; you wouldn’t necessarily need a custom dataset. Here’s an example using the existing CIFAR10 dataset in torchvision:

import torch
import torchvision

# FiveCrop returns 5 crops; stack them into one [5, C, H, W] tensor and repeat the label per crop.
dataset = torchvision.datasets.CIFAR10(
    root='.',
    transform=torchvision.transforms.Compose([
        torchvision.transforms.FiveCrop(24),
        torchvision.transforms.Lambda(lambda crops: torch.stack(
            [torchvision.transforms.PILToTensor()(crop) for crop in crops]))]),
    target_transform=torchvision.transforms.Lambda(
        lambda label: torch.ones(5, dtype=torch.int) * label),
    download=True)
print(dataset[0][0].shape, dataset[0][1].shape)
# The collate_fn flattens the per-sample [5, C, H, W] stacks into a [B * 5, C, H, W] batch.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=4,
    collate_fn=lambda data: (torch.concat([i[0] for i in data]),
                             torch.concat([i[1] for i in data])))
for input, target in dataloader:
    print(input.shape)
    print(target.shape)
    print(target)
    break
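
With batch_size=4 and five 24×24 crops per image, this should print torch.Size([20, 3, 24, 24]) for the inputs and torch.Size([20]) for the targets, with each label repeated five times; the collate_fn flattens the crop dimension into the batch dimension, so downstream training code sees ordinary 4D batches.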