Apply different transforms (data augmentation) to train and validation sets

My dataset is organized into a Train folder and a Test folder. When I run experiments, I further split the data in the Train folder into training and validation sets.

However, the transform is applied before the split, so it is the same for both the training and validation sets. How can I apply a different transform to each in this case?

Transform code:

from torchvision import transforms

data_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

Dataset Code

from torchvision import datasets

train_data = datasets.ImageFolder(base_path + '/train/',
                                  transform=data_transform)

Train and Validation Split and Loader Code

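import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler

# valid_size, batch_size and num_workers are assumed to be defined earlier
# obtain training indices that will be used for validation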
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# prepare data loaders 
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    sampler=valid_sampler, num_workers=num_workers)
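The index split itself is fine: the two index lists are disjoint and together cover the whole Train folder. The problem is that both samplers draw from the same train_data object, so every sample goes through the same data_transform. A quick check of the split arithmetic, using the hypothetical values num_train = 10 and valid_size = 0.2:

```python
import numpy as np

# Hypothetical values, chosen only to make the arithmetic easy to follow
num_train = 10
valid_size = 0.2

indices = list(range(num_train))
np.random.seed(0)             # fixed seed so the split is reproducible
np.random.shuffle(indices)    # shuffles the list in place
split = int(np.floor(valid_size * num_train))  # floor(0.2 * 10) = 2
train_idx, valid_idx = indices[split:], indices[:split]

print(split, len(train_idx), len(valid_idx))  # 2 8 2
```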

In your case you have one dataset and two samplers.
I usually split the dataset into a training and a validation dataset instead:

tng_dataset = torch.utils.data.Subset(train_data, train_idx)
val_dataset = torch.utils.data.Subset(train_data, valid_idx)

Then, instead of applying the transformation when creating the ImageFolder dataset, you can apply it to each split dataset individually using a helper class like this:

import torch


class MapDataset(torch.utils.data.Dataset):
    """
    Given a dataset, creates a dataset which applies a mapping function
    to its items (lazily, only when an item is called).

    Note that data is not cloned/copied from the initial dataset.
    """

    def __init__(self, dataset, map_fn):
        self.dataset = dataset
        self.map = map_fn

    def __getitem__(self, index):
        return self.map(self.dataset[index])

    def __len__(self):
        return len(self.dataset)
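A minimal usage sketch of the class above, with a plain list of (input, target) pairs standing in for the dataset (toy_pairs and double_input are hypothetical). It shows that the mapping runs only when an item is indexed and that the wrapped data is left unchanged:

```python
import torch


class MapDataset(torch.utils.data.Dataset):
    """Applies map_fn lazily to every item of an existing dataset."""

    def __init__(self, dataset, map_fn):
        self.dataset = dataset
        self.map = map_fn

    def __getitem__(self, index):
        return self.map(self.dataset[index])

    def __len__(self):
        return len(self.dataset)


# Hypothetical toy dataset: a plain list of (input, target) pairs
toy_pairs = [(1.0, 0), (2.0, 1), (3.0, 0)]

calls = []


def double_input(item):
    calls.append(item)  # record each call to show that mapping is lazy
    x, y = item         # the mapping receives the whole (input, target) item
    return x * 2, y


mapped = MapDataset(toy_pairs, double_input)
print(len(calls))   # 0: nothing has been mapped yet
print(mapped[1])    # (4.0, 1)
print(len(calls))   # 1
print(toy_pairs[1]) # (2.0, 1): the original data is untouched
```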

Note that here the mapping function is applied to the whole output of the dataset, which may include both the input and the target. You might want to either pass a mapping function that handles this or modify the class to suit your needs.
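One way to get "a mapping function that handles this" is a small wrapper that applies an image-only transform to the first element of each item and passes the label through. ApplyToInput is a hypothetical name, and str.upper only stands in for a torchvision transform so the check runs without images. A top-level class is used rather than a lambda because lambdas cannot be pickled, which matters when the DataLoader uses num_workers > 0 with the "spawn" start method:

```python
class ApplyToInput:
    """Hypothetical helper: applies `transform` to the input of an
    (input, target) pair and returns the target unchanged."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, item):
        x, y = item
        return self.transform(x), y


# Toy check with a plain callable standing in for a torchvision transform
to_upper = ApplyToInput(str.upper)
print(to_upper(("cat", 3)))  # ('CAT', 3)
```

With the original MapDataset this would be used as MapDataset(tng_dataset, ApplyToInput(data_transform)).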

and finally:

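# data_transform must handle the (image, label) tuple returned by ImageFolder (see note above)
tng_data_tf = MapDataset(tng_dataset, data_transform)
# a Subset has no sampler, so shuffle=True restores random ordering for training
train_loader = torch.utils.data.DataLoader(tng_data_tf, batch_size=batch_size,
    shuffle=True, num_workers=num_workers)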

Thanks.

Should I still use ImageFolder to obtain the variable train_data?
For example, train_data = datasets.ImageFolder(base_path + '/train/') without a transform?

You’re welcome!

Should I still use ImageFolder to obtain the variable train_data ?

Yes, exactly!

If this solves your issue, please mark the previous answer as the solution.

Thanks for your answer. I will try and confirm it ASAP.

Hi gregunz, apologies for the late reply.
Your code is great, but in my case __getitem__ needed a small change to access the images and labels.
I combined your code with the code from the thread "Using ImageFolder, random_split with multiple transforms".

The resulting code seems to work for me. Let me know whether I did it correctly. Feel free to refine it if there are any mistakes, and then I will accept it as the solution.

class MapDataset(torch.utils.data.Dataset):
    """
    Given a dataset, creates a dataset which applies a mapping function
    to its items (lazily, only when an item is called).

    Note that data is not cloned/copied from the initial dataset.
    """

    def __init__(self, dataset, map_fn):
        self.dataset = dataset
        self.map = map_fn

    def __getitem__(self, index):
        x, y = self.dataset[index]  # image, label (loaded only once)
        if self.map:
            x = self.map(x)  # transform the image only; the label is unchanged
        return x, y

    def __len__(self):
        return len(self.dataset)
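Since map_fn in this version receives only the image, any callable that takes one image and returns the transformed image can be passed, including a torchvision transforms.Compose([...]) pipeline containing transforms.ToTensor(), as long as the underlying ImageFolder was created without its own transform (otherwise ToTensor() would receive a tensor instead of a PIL image). The sketch below uses hypothetical stand-ins so it runs without image files: a list of tensors replaces ImageFolder, and train_transform/val_transform replace the torchvision pipelines.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset


class MapDataset(Dataset):
    """Applies map_fn to the image of each (image, label) item."""

    def __init__(self, dataset, map_fn):
        self.dataset = dataset
        self.map = map_fn

    def __getitem__(self, index):
        x, y = self.dataset[index]
        if self.map:
            x = self.map(x)
        return x, y

    def __len__(self):
        return len(self.dataset)


def val_transform(x):
    # Hypothetical deterministic validation transform: rescale to [0, 1]
    return x / 255.0


def train_transform(x):
    # Hypothetical training augmentation: random noise, then the same rescaling
    return val_transform(x + torch.randn_like(x))


# Toy base dataset with no transform of its own: 10 "images" of shape (3, 4, 4)
toy_images = [(torch.full((3, 4, 4), float(i)), i % 2) for i in range(10)]

train_idx, valid_idx = list(range(2, 10)), [0, 1]
train_set = MapDataset(Subset(toy_images, train_idx), train_transform)
valid_set = MapDataset(Subset(toy_images, valid_idx), val_transform)

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=4)

x, y = next(iter(valid_loader))
print(x.shape, y.tolist())  # torch.Size([2, 3, 4, 4]) [0, 1]
```

With real data, the base dataset would be datasets.ImageFolder(base_path + '/train/') and the two map functions would be the training and validation torchvision pipelines.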


It doesn’t work for me: it doesn’t apply anything and returns the pictures unchanged.

Can you show me an example of a function to pass to this class?
How can I apply classic transformations like ToTensor() with this method?

The best solution is actually to load the dataset twice and apply a different transformation to each copy. I haven’t tested the following code, but I hope you get the idea. The key point is that train_data and valid_data are exactly the same except for their data_transform.

train_data = datasets.ImageFolder(base_path + '/train/',
                                           transform=data_transform_train)
valid_data = datasets.ImageFolder(base_path + '/train/',
                                           transform=data_transform_val)

# obtain training indices that will be used for validation
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# prepare data loaders 
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    sampler=valid_sampler, num_workers=num_workers)
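A runnable sketch of the load-twice idea. ToyFolder, flip and the 0.2 validation fraction are hypothetical stand-ins for ImageFolder, data_transform_train and valid_size. Because both dataset objects list the same samples in the same order (ImageFolder builds its sample list deterministically from the same folder), an index refers to the same image in either one, so a single index split can be shared:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler


class ToyFolder(Dataset):
    """Hypothetical stand-in for ImageFolder: fixed samples, configurable transform."""

    def __init__(self, transform=None):
        # 10 fake "images" of shape (C, H, W) = (3, 4, 4) with alternating labels
        base = torch.arange(48, dtype=torch.float32).reshape(3, 4, 4)
        self.samples = [(base + i, i % 2) for i in range(10)]
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.samples[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.samples)


def flip(x):
    # Stand-in for a training augmentation: always flip along the width axis
    return torch.flip(x, dims=[2])


# Two dataset objects over the same samples, differing only in their transform
train_data = ToyFolder(transform=flip)
valid_data = ToyFolder(transform=None)

# Same index split as in the post, with valid_size = 0.2
indices = list(range(len(train_data)))
np.random.seed(0)
np.random.shuffle(indices)
split = int(np.floor(0.2 * len(train_data)))
train_idx, valid_idx = indices[split:], indices[:split]

train_loader = DataLoader(train_data, batch_size=4,
                          sampler=SubsetRandomSampler(train_idx))
# The validation loader wraps valid_data, so it sees the un-augmented images
valid_loader = DataLoader(valid_data, batch_size=4,
                          sampler=SubsetRandomSampler(valid_idx))
```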


Hey gregunz and xinqi,
thanks for this amazing piece of code, it helped me a lot!
Greetings,
Nico

Shouldn’t the last loader be torch.utils.data.DataLoader(valid_data, …) instead of torch.utils.data.DataLoader(train_data, …)?

I think a good solution can be found in the thread "Changing transforms after creating a dataset" (post #7 by Brando_Miranda):

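import sklearn.model_selection
import torch

# MyDataset stands for your own Dataset class that takes a transform argument
train_dataset = MyDataset(train_transform)
val_dataset = MyDataset(val_transform)
indices = list(range(len(train_dataset)))  # both copies contain the same samples
train_indices, val_indices = sklearn.model_selection.train_test_split(indices)
train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)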