How to apply another transform to an existing Dataset?

Transformations also work with a Subset, as shown here:

import torch
from torchvision import datasets, transforms

# Not recommended: setting the transform attribute after the dataset has been created
dataset = datasets.MNIST(root='/home/pbialecki/python/data', train=True, transform=None, download=False)

print(type(dataset[0][0]))
# > <class 'PIL.Image.Image'>

dataset = torch.utils.data.Subset(dataset, indices=torch.arange(10))
print(type(dataset[0][0]))
# > <class 'PIL.Image.Image'>

# Subset has no transform of its own, so set it on the wrapped dataset
dataset.dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(20),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

print(type(dataset[0][0]))
# > <class 'torch.Tensor'>

# Recommended: pass the transform when creating the dataset
transform = transforms.Compose([
    transforms.RandomResizedCrop(20),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset = datasets.MNIST(root='/home/pbialecki/python/data', train=True, transform=transform, download=False)

print(type(dataset[0][0]))
# > <class 'torch.Tensor'>

dataset = torch.utils.data.Subset(dataset, indices=torch.arange(10))
print(type(dataset[0][0]))
# > <class 'torch.Tensor'>

As previously described: a Subset has no .transform attribute of its own, so if you want to set the transformation late (which I would not recommend), you have to go through the wrapped dataset via dataset.dataset.transform.
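
One caveat with the late assignment: dataset.dataset.transform changes the underlying dataset object, so every Subset that shares it (e.g. a train/validation split) sees the same transform. If you only want a transform on one of the subsets, a small wrapper is one option. This is just a minimal sketch: TransformedSubset is a made-up helper name (not part of PyTorch), and the TensorDataset is a toy stand-in for MNIST so the snippet runs without any download.

```python
import torch
from torch.utils.data import Dataset, Subset, TensorDataset


class TransformedSubset(Dataset):
    """Hypothetical helper: applies a transform on top of any map-style dataset."""

    def __init__(self, base, transform=None):
        self.base = base
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        # Assumes the wrapped dataset returns (sample, target) pairs.
        sample, target = self.base[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target


# Toy stand-in for MNIST: 20 samples with 4 features each, all ones.
base = TensorDataset(torch.ones(20, 4), torch.arange(20))

train_subset = Subset(base, indices=range(0, 10))
val_subset = Subset(base, indices=range(10, 20))

# Only the training view gets the extra transform (here: doubling the values).
train_view = TransformedSubset(train_subset, transform=lambda x: x * 2)

train_sample, train_target = train_view[0]
val_sample, val_target = val_subset[0]
```

The validation subset is left untouched here, since the transform lives in the wrapper and not in the shared base dataset.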
