Changing transformation applied to data during training

I would like to change the transformation I am applying to data during training. For example, I might want to change the size of the random crop I am taking of images from 32 to 28 or change the amount of jitter applied to an image. Is there a way of doing this that works with the DataLoader class when num_workers > 0?

Thanks for the help!

You can apply transformations to the dataset and then wrap it in a DataLoader:

import torch
from torchvision import datasets, transforms

data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),  # RandomSizedCrop is deprecated
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])

train_dataset = datasets.ImageFolder(root='data/train',
                                     transform=data_transform)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=4, shuffle=True,
                                           num_workers=4)

Additional resources on writing a custom Dataset class can be found here: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#

I’m aware of this. I want to change the transformation that the DataLoader is using during training.

You could write your own Dataset and add some logic to switch between the different transformations.
I created a small code snippet using different “stages” for the transformations.
You should apply your own logic there, since I’m not sure when you want to switch the transforms.
Take care with the different sizes, since your model might complain about a size mismatch! :wink:

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF


class MyData(Dataset):
    def __init__(self):
        # ten dummy uint8 images of shape [3, 48, 48]
        self.images = [TF.to_pil_image(x) for x in torch.ByteTensor(10, 3, 48, 48)]
        self.set_stage(0)  # initial stage
        
    def __getitem__(self, index):
        image = self.images[index]
        
        # Just apply your transformations here
        image = self.crop(image)
        x = TF.to_tensor(image)
        return x
        
    def set_stage(self, stage):
        if stage == 0:
            print('Using (32, 32) crops')
            self.crop = transforms.RandomCrop((32, 32))
        elif stage == 1:
            print('Using (28, 28) crops')
            self.crop = transforms.RandomCrop((28, 28))
        
    def __len__(self):
        return len(self.images)


dataset = MyData()
loader = DataLoader(dataset,
                    batch_size=2,
                    num_workers=2,
                    shuffle=True)

for batch_idx, data in enumerate(loader):
    print('Batch idx {}, data shape {}'.format(
        batch_idx, data.shape))
    
# switch the stage between epochs; non-persistent workers are re-created
# when iteration restarts, so they pick up the change
loader.dataset.set_stage(1)

for batch_idx, data in enumerate(loader):
    print('Batch idx {}, data shape {}'.format(
        batch_idx, data.shape))

Thank you, this is exactly what I need!

I know that the solution you provided works, but what confuses me is why it doesn’t work if I actually have a FloatTensor:

inputs = transform(inputs)

Error:

TypeError: pic should be PIL Image or ndarray. Got <class 'torch.FloatTensor'>

I see what the error says, but why wouldn’t transforms work for FloatTensors? It seems so counterintuitive to me.

Currently only PIL Images are supported, so you would have to convert your tensor to an Image before calling transform.
I think the reason is that the transformations rely heavily on PIL functions and thus cannot easily be applied to tensors.
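For example (a sketch, assuming a 3xHxW FloatTensor with values in [0, 1]):

import torch
from torchvision import transforms
import torchvision.transforms.functional as TF

inputs = torch.rand(3, 48, 48)  # stand-in for your FloatTensor image
transform = transforms.RandomCrop((28, 28))

pil_image = TF.to_pil_image(inputs)  # tensor -> PIL Image
pil_image = transform(pil_image)     # apply the PIL-based transform
inputs = TF.to_tensor(pil_image)     # back to a FloatTensor in [0, 1]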

Put transforms.ToTensor() after the other transforms in your Compose, so that the PIL-based transforms run first.
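A minimal sketch of that ordering:

from torchvision import transforms

data_transform = transforms.Compose([
    transforms.RandomCrop((28, 28)),    # PIL-based transforms first
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()])             # convert to a tensor last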

Exactly. If I transform the tensor to PIL and vice versa, then the final result is not accurate. How do I solve this problem? I would be much obliged if you could create an example for the MNIST dataset. For my test cases, the data loader gives three outputs: data, transformation(data), and target.

Could you check the data type of your input?
We had a similar issue recently, where the inputs were FloatTensors and thus an overflow occurred.
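For example (a sketch; to_pil_image scales float inputs by 255 and casts to uint8, so out-of-range values wrap around):

import torch
import torchvision.transforms.functional as TF

x = torch.rand(3, 48, 48) * 300  # hypothetical out-of-range FloatTensor
print(x.dtype)  # torch.float32
pil_image = TF.to_pil_image(x.clamp(0, 1))  # clamp (or convert to uint8) first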

What if I already have a dataset object? Would:

valset.transform = new_transform

work?

It depends on your workflow.
The manipulation itself would work, and valset would use the new_transform whenever self.transform is called. However, if you are wrapping valset into a DataLoader with multiple workers, you have to be careful about when (and if) this change becomes visible.
When you start iterating the DataLoader, each worker creates its own copy of the Dataset, which lives until the loop finishes. Changing the valset via loader.dataset.transform = new_transform would then be visible in the next epoch (i.e. when you restart the DataLoader loop). Also, if you are using persistent_workers=True, the workers never restart (and thus never create a new copy of the dataset), so the change won’t be used.
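A minimal sketch of that behavior, using torchvision’s FakeData as a stand-in dataset and a hypothetical replacement transform:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

valset = datasets.FakeData(size=8, transform=transforms.ToTensor())
loader = DataLoader(valset, batch_size=4, num_workers=2)

for data, target in loader:  # each worker holds its own copy of valset here
    pass

# visible in the next epoch, since non-persistent workers are re-created
loader.dataset.transform = transforms.Compose([
    transforms.RandomCrop((28, 28)),
    transforms.ToTensor()])

for data, target in loader:
    print(data.shape)  # now reflects the new transform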

Thanks!

I think a good solution can be found here: Changing transforms after creating a dataset - #7 by Brando_Miranda

import sklearn.model_selection
import torch

# MyDataset is your own Dataset class accepting a transform argument
train_dataset = MyDataset(train_transform)
val_dataset = MyDataset(val_transform)
indices = list(range(len(train_dataset)))  # shared indices, split once
train_indices, val_indices = sklearn.model_selection.train_test_split(indices)
train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)
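Since both MyDataset instances wrap the same underlying data but carry different transforms, the disjoint index split via Subset ensures each sample is seen with exactly one of the two transform pipelines.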