Different Transformations for Different Classes

I would like to perform different transformations on 5 different classes in a multi-class classification problem. What’s the cleanest way to do this using ImageFolder and DataLoader?

Note: The classes also have quite different sizes, ranging from 50 to 3,000 images.
How can I set this up so the model doesn't overfit to the classes with more data?

My data structure looks like this:

/data
    /train
        /class1
        /class2
        /class3
        /class4
        /class5
    /test
        /class1
        /class2
        /class3
        /class4
        /class5

The easiest approach that comes to mind is to write a custom Dataset class that explicitly applies a different transformation to each class.
For the imbalanced dataset, you may want to use WeightedRandomSampler.
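A minimal sketch of the WeightedRandomSampler suggestion, assuming you already have one integer label per training sample (for an ImageFolder dataset that list is its targets attribute). Each sample is weighted by the inverse of its class size, so every class ends up with the same total weight and is drawn with roughly equal probability. make_balanced_sampler is a hypothetical helper name, not a PyTorch API.

```python
from collections import Counter
from typing import Sequence

import torch
from torch.utils.data import WeightedRandomSampler


def make_balanced_sampler(labels: Sequence[int]) -> WeightedRandomSampler:
    """Hypothetical helper: weight each sample by 1 / (size of its class).

    A class with 50 images then gets the same total weight as a class
    with 3,000 images.
    """
    counts = Counter(labels)
    sample_weights = torch.tensor([1.0 / counts[y] for y in labels], dtype=torch.double)
    # replacement=True lets minority-class samples be drawn more than once per epoch
    return WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
```

Pass the result as DataLoader(..., sampler=sampler) and drop shuffle=True, since DataLoader treats sampler and shuffle as mutually exclusive. Because small classes are oversampled with replacement, their images repeat within an epoch, which is one reason augmenting those classes more heavily may help.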

Currently my setup is something like this:

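from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder

class ConcatDataset(Dataset):
    '''Concatenates Datasets'''

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        # Returns one (image, label) pair from each dataset, i.e. it zips them
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        # Length of the shortest dataset
        return min(len(d) for d in self.datasets)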

train_1_transforms = ...
train_2_transforms = ...
train_3_transforms = ...
train_4_transforms = ...
train_5_transforms = ...

train_1_data = ImageFolder(train_1_dir, transform=train_1_transforms)
train_2_data = ImageFolder(train_2_dir, transform=train_2_transforms)
train_3_data = ImageFolder(train_3_dir, transform=train_3_transforms)
train_4_data = ImageFolder(train_4_dir, transform=train_4_transforms)
train_5_data = ImageFolder(train_5_dir, transform=train_5_transforms)

train_concat_data = ConcatDataset(train_1_data, train_2_data, train_3_data, train_4_data, train_5_data)

train_loader = DataLoader(train_concat_data, batch_size=64, shuffle=True, pin_memory=True, num_workers=4)

I’m getting this error:

RuntimeError: Found 0 files in subfolders of: ./data/train/class_1

I don’t know how else to apply a separate transformation to each class.
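The error comes from ImageFolder itself: it expects its root directory to contain one subfolder per class, so pointing it at a single class folder such as ./data/train/class_1 finds no class subfolders and therefore no images. One way to keep ImageFolder is to load the whole ./data/train tree once without a transform and wrap it in a small dataset that picks the transform from each sample's label. This is a sketch; PerClassTransform and transforms_by_label are hypothetical names, not part of torchvision.

```python
from typing import Any, Callable, Dict, Sequence, Tuple

from torch.utils.data import Dataset


class PerClassTransform(Dataset):
    """Hypothetical wrapper: transform each (sample, label) pair of a base
    dataset according to its label."""

    def __init__(self, base: Sequence[Tuple[Any, int]],
                 transforms_by_label: Dict[int, Callable[[Any], Any]]):
        self.base = base
        self.transforms_by_label = transforms_by_label

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        sample, label = self.base[index]
        # Leave the sample untouched if no transform is registered for this label
        transform = self.transforms_by_label.get(label)
        if transform is not None:
            sample = transform(sample)
        return sample, label
```

With torchvision, base would be ImageFolder("./data/train") created without a transform argument, and the integer keys of transforms_by_label would come from base.class_to_idx, for example base.class_to_idx["class1"] mapped to train_1_transforms.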

Don’t use ImageFolder. Instead, read the images yourself (with os.walk, for example) and store their paths and labels; the label can be a number assigned according to the class folder each image resides in. Then, in __getitem__, after reading the image, use its label to apply the corresponding transformation and return the (image, label) pair.
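A minimal sketch of that approach, assuming the /data/train layout shown above (one subfolder per class). PerClassImageDataset, pil_loader and the IMG_EXTENSIONS tuple are hypothetical names; the transforms are keyed by class folder name, and loading via Pillow is an assumption you can replace through the loader argument.

```python
import os
from typing import Any, Callable, Dict, List, Tuple

from torch.utils.data import Dataset

# Assumption: extend this tuple if your images use other formats
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def pil_loader(path: str) -> Any:
    # Pillow is imported lazily so it is only required when images are actually read
    from PIL import Image
    with open(path, "rb") as f:
        return Image.open(f).convert("RGB")


class PerClassImageDataset(Dataset):
    """Hypothetical dataset: one subfolder per class, one transform per class."""

    def __init__(self, root: str, transforms_by_class: Dict[str, Callable[[Any], Any]],
                 loader: Callable[[str], Any] = pil_loader):
        # Class folders sorted by name and labelled 0..N-1 (the convention ImageFolder uses)
        self.classes = sorted(d for d in os.listdir(root)
                              if os.path.isdir(os.path.join(root, d)))
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        self.samples: List[Tuple[str, int]] = []
        for cls in self.classes:
            for dirpath, dirnames, filenames in os.walk(os.path.join(root, cls)):
                dirnames.sort()  # deterministic traversal order
                for name in sorted(filenames):
                    if name.lower().endswith(IMG_EXTENSIONS):
                        self.samples.append((os.path.join(dirpath, name), self.class_to_idx[cls]))
        # Per-sample labels, e.g. for computing WeightedRandomSampler weights
        self.targets = [label for _, label in self.samples]
        # Re-key the transforms by integer label; keys must be class folder names
        self.transforms = {self.class_to_idx[c]: t for c, t in transforms_by_class.items()}
        self.loader = loader

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        path, label = self.samples[index]
        image = self.loader(path)
        # Apply the transform registered for this image's class, if any
        transform = self.transforms.get(label)
        if transform is not None:
            image = transform(image)
        return image, label
```

Usage would look like PerClassImageDataset("./data/train", {"class1": train_1_transforms, "class2": train_2_transforms, ...}), passed to a single DataLoader. Since the whole tree is indexed once, there is no need to combine five separate datasets.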