How to augment the minority class only in an unbalanced dataset

I have an unbalanced image dataset in which the positive class makes up only 1/10 of the samples. Classification models trained on this dataset tend to be biased toward the majority class (a high false negative rate and a low false positive rate). To counter this, I would like to augment only the minority class. As I understand it, passing transforms (random rotation, etc.) when creating a torchvision dataset augments all of the data, so the imbalance stays the same.
What is the best way to go about only augmenting images from a specific class?

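You could write your own `Dataset` and apply the transformation in its `__getitem__` method only when a sample belongs to the minority class:

```python
from torch.utils.data import Dataset


class MyData(Dataset):
    def __init__(self, data, target, transform=None):
        self.data = data
        self.target = target
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        # Check for the minority class (label 0 is assumed to be the minority
        # here; adjust the check to match your own labels).
        if y == 0 and self.transform is not None:
            x = self.transform(x)
        return x, y
```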

EDIT: Another useful approach is to use the WeightedRandomSampler and oversample the minority class.
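A minimal sketch of the oversampling route, assuming a made-up 90:10 toy dataset in which label 1 is the minority class. The per-sample weights are set to the inverse of each sample's class frequency, which is one common choice rather than the only one; the tensors and sizes are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical toy data: 90 majority samples (label 0), 10 minority samples (label 1).
targets = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
data = torch.randn(100, 3)
dataset = TensorDataset(data, targets)

# Inverse class frequency: both classes end up with the same total weight.
class_counts = torch.bincount(targets)       # tensor([90, 10])
class_weights = 1.0 / class_counts.float()   # 1/90 for class 0, 1/10 for class 1
sample_weights = class_weights[targets]      # one weight per sample

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # draw one "epoch" worth of indices
    replacement=True,                 # the default; lets minority indices repeat
)
# Note: shuffle must stay False (the default) when a sampler is passed.
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

# Count how often each class is drawn during one pass over the loader.
drawn = torch.cat([y for _, y in loader])
drawn_counts = torch.bincount(drawn, minlength=2)
print(drawn_counts)
```

Because both classes carry the same total weight, each class should account for about half of the draws in expectation; the exact counts vary from run to run.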


Thank you very much @ptrblck. This applies random transformations to the minority-class images without changing the size of the minority class, right? That is, if my dataset has 100 images, 90 in one class and 10 in the other, I get the effect of augmentation by iterating over it many times, not by physically building a balanced 90:90 set in which 80 extra minority images were created through transformations. Am I right?
Also, oversampling with WeightedRandomSampler just repeats minority samples (exact copies) to reach the balanced target weights, correct?

You are correct regarding the transformation. The transformation will be applied on the fly to your minority-class data.

You are also correct regarding the WeightedRandomSampler, as long as you keep the default replacement=True argument.
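To see what `replacement=True` means in practice, the sketch below draws indices directly from the sampler for a hypothetical 90:10 layout (indices 90 to 99 are the minority, label 1) and counts the total minority draws versus the distinct minority indices. With `replacement=False`, each index can be drawn at most once, so a pass of 100 draws is just a permutation and contains exactly the 10 original minority samples.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical layout: indices 0-89 are majority (label 0), 90-99 are minority (label 1).
targets = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
sample_weights = (1.0 / torch.bincount(targets).float())[targets]


def minority_draws(replacement):
    """Return (total minority draws, distinct minority indices) for one pass of 100 draws."""
    sampler = WeightedRandomSampler(sample_weights, num_samples=100, replacement=replacement)
    indices = torch.tensor(list(sampler))
    minority = indices[indices >= 90]
    return minority.numel(), minority.unique().numel()


with_repl = minority_draws(replacement=True)
without_repl = minority_draws(replacement=False)
print("replacement=True :", with_repl)     # many draws, at most 10 distinct indices
print("replacement=False:", without_repl)  # each index once: (10, 10)
```

If the sampler is combined with a dataset that augments minority samples inside `__getitem__`, each repeated index is transformed again every time it is drawn, so with random transforms the repeats are no longer identical copies.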

@ptrblck Can you give an example of how to select the weights?

I’ve created some examples here and here.
Let me know if you need more explanation!
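The linked examples are not reproduced in this thread, so here is one hedged way to think about choosing the weights. `make_sample_weights` is a hypothetical helper, not a PyTorch function: it gives each class a target share of the draws and splits that share evenly across the class's samples, i.e. weight = share / class_count. Equal shares reduce to the usual inverse-frequency weights (up to a constant factor); unequal shares oversample the minority only partially. `WeightedRandomSampler` does not need its weights to sum to one, but normalizing the shares makes the per-class probabilities easy to read.

```python
import torch


def make_sample_weights(targets, class_shares):
    """Hypothetical helper: per-sample weights such that class c is drawn with
    probability class_shares[c] (shares are normalized to sum to 1).
    Assumes every class appears at least once in targets."""
    shares = torch.as_tensor(class_shares, dtype=torch.double)
    shares = shares / shares.sum()
    counts = torch.bincount(targets, minlength=len(shares)).double()
    # Split each class's share evenly over that class's samples.
    per_class = shares / counts
    return per_class[targets]


# The 90:10 example from this thread, with label 1 as the minority class.
targets = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])

balanced = make_sample_weights(targets, [1, 1])  # proportional to 1 / class_count
partial = make_sample_weights(targets, [2, 1])   # minority drawn with probability 1/3

for name, w in [("balanced", balanced), ("partial", partial)]:
    # Summing the weights of each class gives that class's draw probability.
    per_class_prob = torch.stack([w[targets == c].sum() for c in range(2)])
    print(name, per_class_prob.tolist())
```

The resulting tensor can be passed directly as the `weights` argument of `WeightedRandomSampler`.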