Transformations work with Subsets, as seen here:
import torch
from torchvision import datasets, transforms

# not recommended approach: setting the transform attribute late
dataset = datasets.MNIST(root='/home/pbialecki/python/data', train=True, transform=None, download=False)
print(type(dataset[0][0]))
# > <class 'PIL.Image.Image'>
dataset = torch.utils.data.Subset(dataset, indices=torch.arange(10))
print(type(dataset[0][0]))
# > <class 'PIL.Image.Image'>
# the Subset has no transform of its own, so set it on the wrapped dataset
dataset.dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(20),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
print(type(dataset[0][0]))
# > <class 'torch.Tensor'>

# recommended approach: pass the transform when creating the dataset
transform = transforms.Compose([
    transforms.RandomResizedCrop(20),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset = datasets.MNIST(root='/home/pbialecki/python/data', train=True, transform=transform, download=False)
print(type(dataset[0][0]))
# > <class 'torch.Tensor'>
dataset = torch.utils.data.Subset(dataset, indices=torch.arange(10))
print(type(dataset[0][0]))
# > <class 'torch.Tensor'>
As previously described, if you want to set the internal .transform attribute late (which I would not recommend), you have to access the internal dataset via dataset.dataset.transform, since the Subset itself only forwards the indexing to the wrapped dataset and does not apply any transformation.
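To illustrate why the late approach is fragile, here is a minimal sketch that uses a hypothetical ToyDataset in place of MNIST, so nothing needs to be downloaded. It only assumes that the wrapped dataset applies self.transform inside __getitem__, as the torchvision datasets do. It shows that assigning .transform on the Subset itself has no effect, and that assigning it via .dataset changes every Subset sharing the same underlying dataset (e.g. train/validation splits created with torch.utils.data.random_split).

```python
import torch
from torch.utils.data import Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical stand-in for a dataset such as MNIST: it stores raw
    samples and applies self.transform (if set) inside __getitem__."""

    def __init__(self, n, transform=None):
        self.data = torch.arange(n, dtype=torch.float32)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        if self.transform is not None:
            x = self.transform(x)
        return x


base = ToyDataset(10)
train = Subset(base, indices=range(0, 8))
val = Subset(base, indices=range(8, 10))

# Setting .transform on the Subset only creates a new, unused attribute:
# Subset.__getitem__ simply returns self.dataset[self.indices[idx]].
train.transform = lambda x: x * 100
print(train[1])  # tensor(1.)

# Setting it on the wrapped dataset takes effect ...
train.dataset.transform = lambda x: x * 100
print(train[1])  # tensor(100.)

# ... but it also changes every other Subset sharing the same base dataset.
print(val[0])  # tensor(800.)
```

If the splits need different transforms, passing the transform at dataset creation (or creating separate dataset instances per split) avoids this shared-state surprise.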