How to augment only the minority class in an imbalanced dataset

You could write your own `Dataset` and, in its `__getitem__` method, apply the transformation only when the sample's label belongs to the minority class.

class MyData(Dataset):
    """Map-style dataset that applies ``transform`` only to minority-class samples.

    Samples whose label is in ``minority_classes`` are passed through
    ``transform``; every other sample is returned unchanged, so augmentation
    is confined to the under-represented class(es).

    Args:
        data: Indexable collection of inputs (e.g. a list or a tensor).
        target: Indexable collection of labels, aligned index-for-index
            with ``data``.
        transform: Optional callable applied to minority-class inputs.
            ``None`` disables augmentation entirely.
        minority_classes: Iterable of labels whose samples should be
            augmented. Keyword-only; defaults to ``(0,)``, i.e. class 0.
    """

    def __init__(self, data, target, transform=None, *, minority_classes=(0,)):
        self.data = data
        self.target = target
        self.transform = transform
        # Stored as a tuple (not a set) so that ``in`` falls back to ``==``:
        # tensor / NumPy scalar labels hash by identity, not by value, so a
        # set lookup would never match them.
        self.minority_classes = tuple(minority_classes)

    def __len__(self):
        """Return the number of samples, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(x, y)``, with ``x`` transformed only for minority labels."""
        x = self.data[index]
        y = self.target[index]

        # ``is not None`` instead of truthiness: some callables (e.g. an empty
        # ``nn.Sequential``) define ``__len__`` and would evaluate as False.
        if self.transform is not None and y in self.minority_classes:
            x = self.transform(x)

        return x, y

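EDIT: Another useful approach is to use `WeightedRandomSampler` to oversample the minority class. Give each sample a weight inversely proportional to its class frequency, so that minority-class samples are drawn more often.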

6 Likes