How to apply another transform to an existing Dataset?

This is a code example:

dataset = datasets.MNIST(root=root, train=istrain, transform=None)  #preserve raw img

print(type(dataset[0][0]))
# <class 'PIL.Image.Image'>

dataset = torch.utils.data.Subset(dataset, indices=SAMPLED_INDEX) # for resample

for ind in range(len(dataset)):
    img, label = dataset[ind] # <class 'PIL.Image.Image'> <class 'int'>/<class 'numpy.int64'>
    img.save(fp=os.path.join(saverawdir, f'{ind:02d}-{int(label):02d}.png'))

dataset.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])
#transform for net forwarding

print(type(dataset[0][0]))
# expected <class 'torch.Tensor'>, however it's still <class 'PIL.Image.Image'>

Since the dataset is randomly resampled, I don't want to reload a new dataset with the transform; I just want to apply the transform to the already existing dataset.

Thanks for your help :smiley:


Subset will wrap the passed Dataset in the .dataset attribute, so you would have to add the transformation via:

dataset.dataset.transform = transforms.Compose([
                transforms.RandomResizedCrop(28),
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])

Without the Subset wrapper, your original code would work as expected.
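To make the difference concrete without downloading MNIST, here is a minimal sketch. `ToyDataset` is a hypothetical stand-in that, like MNIST, reads `self.transform` inside `__getitem__`; the integer "images" and the lambda transform are assumptions for illustration only.

```python
from torch.utils.data import Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical stand-in for MNIST: applies self.transform in __getitem__."""

    def __init__(self, n=5, transform=None):
        self.data = list(range(n))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, idx % 2  # (sample, label)


subset = Subset(ToyDataset(), indices=[0, 2, 4])

# Setting the attribute on the Subset only adds an unused attribute to the wrapper
subset.transform = lambda x: x * 10
print(subset[1])  # (2, 0): the wrapped dataset still has transform=None

# Setting it on the wrapped dataset changes what its __getitem__ applies
subset.dataset.transform = lambda x: x * 10
print(subset[1])  # (20, 0)
```

`subset[1]` maps to index 2 of the wrapped dataset, which is why the label is 0 in both calls.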


I'm using the same code, but it's not working for me…
I initially created a dataset from ImageFolder, then used random_split to get the train and val sets, and then tried to apply transformations with the method above, but it's not working.
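With `random_split` there is an extra catch: every returned `Subset` wraps the same underlying dataset object, so setting `train_set.dataset.transform` would also change the samples returned by the validation split. The sketch below uses a small `TensorDataset` as a stand-in for the `ImageFolder` dataset and shows one possible workaround: a hypothetical `TransformedSubset` wrapper (not a PyTorch class) that applies its own transform per split. The `+ 100` lambda is an assumed placeholder for a real transform pipeline.

```python
import torch
from torch.utils.data import Dataset, TensorDataset, random_split

# Stand-in for ImageFolder(root, transform=None)
base = TensorDataset(torch.arange(10).float(), torch.arange(10) % 2)
train_set, val_set = random_split(base, [8, 2], generator=torch.Generator().manual_seed(0))

# Both Subsets wrap the *same* dataset object
print(train_set.dataset is val_set.dataset)  # True


class TransformedSubset(Dataset):
    """Hypothetical helper (not part of torch): applies its own transform on top of a Subset."""

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, y = self.subset[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, y


# Each split gets its own transform; the shared base dataset stays untouched
train_ds = TransformedSubset(train_set, transform=lambda x: x + 100)
val_ds = TransformedSubset(val_set)  # no transform for validation in this sketch
```

For this to work with ImageFolder, the base dataset would be created with `transform=None`, so each wrapper receives the raw PIL images and applies its own pipeline.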

I ran into the same problem. I think transformations can't be applied to a Subset. If you solved this problem, please let me know.

Transformations work with Subsets as seen here:

import torch
from torchvision import datasets, transforms

# not recommended approach of setting the transform attribute late
dataset = datasets.MNIST(root='/home/pbialecki/python/data', train=True, transform=None, download=False)

print(type(dataset[0][0]))
# > <class 'PIL.Image.Image'>

dataset = torch.utils.data.Subset(dataset, indices=torch.arange(10))
print(type(dataset[0][0]))
# > <class 'PIL.Image.Image'>

dataset.dataset.transform = transforms.Compose([
                transforms.RandomResizedCrop(20),
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])


print(type(dataset[0][0]))
# > <class 'torch.Tensor'>

# recommended approach
transform = transforms.Compose([
    transforms.RandomResizedCrop(20),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
dataset = datasets.MNIST(root='/home/pbialecki/python/data', train=True, transform=transform, download=False)

print(type(dataset[0][0]))
# > <class 'torch.Tensor'>

dataset = torch.utils.data.Subset(dataset, indices=torch.arange(10))
print(type(dataset[0][0]))
# > <class 'torch.Tensor'>

As described before: if you want to set the internal .transform attribute late (I would not recommend this approach), you have to access the internal dataset via dataset.dataset.transform.
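If Subsets are nested (e.g. a Subset of a Subset), one extra `.dataset` is needed per level. The sketch below uses a hypothetical `unwrap_subsets` helper (not a PyTorch function) and a hypothetical `ToyDataset` to reach the innermost dataset. It still relies on that dataset reading `self.transform`, so the same caveats about the late-assignment approach apply.

```python
from torch.utils.data import Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical dataset that applies self.transform in __getitem__."""

    def __init__(self, n=10, transform=None):
        self.data = list(range(n))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, idx % 2


def unwrap_subsets(ds):
    """Hypothetical helper: follow .dataset through any number of nested Subsets."""
    while isinstance(ds, Subset):
        ds = ds.dataset
    return ds


nested = Subset(Subset(ToyDataset(), indices=[1, 3, 5, 7]), indices=[0, 2])
unwrap_subsets(nested).transform = lambda x: x * 10
print(nested[1])  # (50, 1): outer index 1 -> inner index 2 -> base index 5
```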


Thanks for the more detailed explanation.
Could you explain why you don't recommend the first approach?

Manipulating the internal .transform attribute assumes that self.transform is indeed used to apply the transformations.
While this might be the case for e.g. MNIST, other datasets could use other attributes (e.g. self.image_transform), and you would need to adapt this manipulation to the actual implementation (which could of course also change between releases).
The right approach is thus to set the transformations once during the initialization of the Dataset and let the Dataset handle them internally, so your code does not depend on its implementation details.
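A minimal sketch of that failure mode, assuming a hypothetical `PairDataset` that stores its transform as `image_transform`: assigning `.transform` late raises no error but has no effect, while passing the transform at initialization works regardless of the attribute name.

```python
from torch.utils.data import Dataset


class PairDataset(Dataset):
    """Hypothetical dataset that stores its transform as `image_transform`, not `transform`."""

    def __init__(self, n=4, image_transform=None):
        self.items = list(range(n))
        self.image_transform = image_transform

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        x = self.items[idx]
        if self.image_transform is not None:
            x = self.image_transform(x)
        return x, 0


late = PairDataset()
late.transform = lambda x: -x  # silently ignored: __getitem__ never reads self.transform
print(late[3])  # (3, 0)

early = PairDataset(image_transform=lambda x: -x)  # passed at init, so the dataset applies it
print(early[3])  # (-3, 0)
```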
