How to augment the minority class only in an unbalanced dataset

I have an unbalanced image dataset in which the positive class makes up only 1/10 of the data. Classification models trained on this dataset tend to be biased toward the majority class (a high false negative rate and a low false positive rate). I would like to augment only the minority class to deal with this. As I understand it, passing transforms (random rotation, etc.) when creating a torchvision.datasets dataset augments all of the data, so the imbalance stays the same.
What is the best way to go about only augmenting images from a specific class?


You could write your own Dataset and apply the transformations in the __getitem__ method.

```python
from torch.utils.data import Dataset


class MyData(Dataset):
    def __init__(self, data, target, transform=None):
        self.data = data
        self.target = target
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]

        # Augment only the minority class (assumed here to have label 0)
        if y == 0 and self.transform is not None:
            x = self.transform(x)

        return x, y
```
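A small self-contained check of this idea, as a sketch under some assumptions: the `minority_label` parameter is a hypothetical addition (the snippet above hard-codes label 0, while the question's minority class is label 1), the data are toy zero tensors with a 90:10 split, and additive Gaussian noise stands in for a torchvision transform. It shows that only minority samples change and that the dataset length stays at 100.

```python
import torch
from torch.utils.data import Dataset


class MinorityAugmentDataset(Dataset):
    """Applies `transform` only to samples whose label equals `minority_label`.

    `minority_label` is a hypothetical parameter added for illustration.
    """

    def __init__(self, data, target, transform=None, minority_label=1):
        self.data = data
        self.target = target
        self.transform = transform
        self.minority_label = minority_label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        # Majority samples pass through unchanged; minority samples are augmented on the fly.
        if y == self.minority_label and self.transform is not None:
            x = self.transform(x)
        return x, y


def add_noise(x):
    # Stand-in random augmentation (assumption): additive Gaussian noise.
    return x + 0.1 * torch.randn_like(x)


# Toy data mirroring the 90:10 split from the question (positive class = label 1).
data = torch.zeros(100, 3, 8, 8)
target = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
dataset = MinorityAugmentDataset(data, target, transform=add_noise, minority_label=1)

# Indices whose returned sample differs from the stored one, i.e. the augmented ones.
changed = [i for i in range(len(dataset)) if not torch.equal(dataset[i][0], data[i])]
print(len(dataset), len(changed))
```

The dataset size is unchanged; the "extra" minority variety only appears as you iterate over it repeatedly.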

EDIT: Another useful approach is to use the WeightedRandomSampler and oversample the minority class.
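A minimal sketch of the sampler approach, assuming a toy `TensorDataset` with a 90:10 split where label 1 is the minority class. Each class gets weight `1 / count`, each sample inherits its class weight, and the sampler draws `len(dataset)` indices with replacement, so an epoch contains roughly equal numbers of both classes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy 90:10 dataset (assumption): label 1 is the minority class.
data = torch.randn(100, 3, 8, 8)
target = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
dataset = TensorDataset(data, target)

# Per-class weight = 1 / class count, then look up each sample's weight by its label.
class_counts = torch.bincount(target)          # counts of label 0 and label 1
class_weights = 1.0 / class_counts.double()
sample_weights = class_weights[target]         # one weight per sample

sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
# When a sampler is passed, shuffle must stay at its default (False).
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

torch.manual_seed(0)
labels_seen = torch.cat([y for _, y in loader])
print(torch.bincount(labels_seen, minlength=2))  # roughly balanced, varies with the seed
```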


Thank you very much @ptrblck. This will apply random transformations to the images in the minority class without physically changing the size of the minority class, right? I mean, if my dataset has 100 images, 90 in one class and 10 in the other, I will get the effect of augmentation by iterating over it many times, not by physically having a 90:90 balanced set where 80 minority images were added through transformations. Am I right?
Also, oversampling through WeightedRandomSampler just adds exact copies of minority samples to reach the balanced target weights, correct?

You are correct regarding the transformation. The transformation will be applied on the fly on your minority class data.

You are also correct regarding the WeightedRandomSampler, if you are keeping the default replacement=True argument.
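A quick way to see what `replacement` changes, as a sketch with toy inverse-frequency weights for a 90:10 split (indices 90 to 99 are assumed to be the minority samples). With `replacement=True` the same minority indices are drawn many times; with `replacement=False` every index appears exactly once, so the weights only affect the order and nothing is oversampled.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)
# 90 majority weights and 10 minority weights (inverse class frequency).
weights = [1 / 90] * 90 + [1 / 10] * 10

# replacement=True: indices are drawn independently, so minority indices repeat.
with_repl = list(WeightedRandomSampler(weights, num_samples=100, replacement=True))
minority_draws = [i for i in with_repl if i >= 90]
print(len(minority_draws), len(set(minority_draws)))  # many draws, at most 10 distinct indices

# replacement=False: each index can be drawn at most once, so no duplicates are created.
without_repl = list(WeightedRandomSampler(weights, num_samples=100, replacement=False))
print(sorted(without_repl) == list(range(100)))
```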


@ptrblck Can you give an example for how to select the weights?

Sure!
I’ve created some examples here and here.
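The linked examples are not reproduced here, so the following is not necessarily what they use. One common choice is inverse class frequency; a short derivation, assuming every sample of class c (with n_c samples, out of C classes) gets the same weight w_c, and that the sampler picks index i with probability proportional to its weight:

```latex
% Probability that a single draw lands in class c:
P(\text{class } c) = \frac{n_c \, w_c}{\sum_{k=1}^{C} n_k \, w_k}
% Inverse-frequency weights make every numerator equal to 1:
w_c = \frac{1}{n_c} \;\Rightarrow\; P(\text{class } c) = \frac{1}{\sum_{k=1}^{C} 1} = \frac{1}{C}
% 90:10 example: w_0 = 1/90, w_1 = 1/10, each class has total weight 1, so P = 1/2 for both.
```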
Let me know if you need more explanation!

Dear @ptrblck,
Thank you very much for your explanation.
You mentioned two approaches to overcome class imbalance: 1) using WeightedRandomSampler and 2) applying torchvision transformations to the minority class only. I have a question about this:
Unlike approach 1, approach 2 does not let us calculate exactly “the number” of newly produced minority-class samples (i.e., those transformed in each epoch). Am I right?

I don’t know how well augmenting only the minority class works to counter class-imbalance issues during training, and note that my code snippet was a direct response to the topic question. It could work, but I haven’t run any experiments.

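The transformation will be applied in each __getitem__ call if the sample comes from the minority class. If you are using random transformations, I would expect to see no repetitions in these samples, and each new epoch will use newly transformed samples for this class. In other words, each epoch transforms exactly as many samples as the minority class contains (e.g., 10 out of 100), but every epoch sees new random versions of them.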
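A sketch that makes this visible, assuming toy data (90 majority samples with label 0, 10 minority samples with label 1) and Gaussian noise as a stand-in random transform. The hypothetical `num_transformed` counter shows that each epoch transforms exactly 10 samples, and comparing the same minority index across two epochs shows that it gets a different augmentation each time.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CountingMinorityDataset(Dataset):
    """Augments minority samples in __getitem__ and counts the calls (illustration only)."""

    def __init__(self, data, target, minority_label=1):
        self.data = data
        self.target = target
        self.minority_label = minority_label
        self.num_transformed = 0  # hypothetical counter; only reliable with num_workers=0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x, y = self.data[index], self.target[index]
        if y == self.minority_label:
            # Random augmentation (assumption): Gaussian noise, re-drawn on every call.
            x = x + 0.1 * torch.randn_like(x)
            self.num_transformed += 1
        return x, y


data = torch.zeros(100, 4)
target = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
dataset = CountingMinorityDataset(data, target)
loader = DataLoader(dataset, batch_size=25, shuffle=False)  # fixed order to compare epochs

epochs = []
for epoch in range(2):
    dataset.num_transformed = 0
    xs = torch.cat([x for x, _ in loader])
    epochs.append(xs)
    print(f"epoch {epoch}: {dataset.num_transformed} samples transformed")

# The same minority sample (index 95) gets a different augmentation in each epoch.
print(torch.equal(epochs[0][95], epochs[1][95]))
```

With `num_workers > 0` each worker holds its own copy of the dataset, so a counter like this would not be aggregated in the main process.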

Understood. Thanks a lot for the clarification.