Classification on class-imbalanced data sets?


I have a dataset with 3 classes with the following items:

  • Class 1: 900 elements
  • Class 2: 15000 elements
  • Class 3: 800 elements

I need to predict class 1 and class 3, which signal important deviations from the norm. Class 2 is the default “normal” case which I don’t care about.

What kind of loss function would I use here?
I was thinking of using CrossEntropyLoss, but since there is a class imbalance, this would need to be weighted, I suppose? How does that work in practice? Like this?

summed = 900 + 15000 + 800
weight = torch.tensor([900 / summed, 15000 / summed, 800 / summed])
crit = nn.CrossEntropyLoss(weight=weight)

Or should the weight be inverted? i.e. 1 / weight?

Is this the right approach to begin with or are there other / better methods I could use?


You could start by setting the weights to the inverse of the class counts (i.e. the inverse of your current weight tensor): with your current approach you would weight the majority class the highest, which is the opposite of what you want.
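A minimal sketch of inverse-count weighting, using the class counts from the question (the normalization step is optional, since only the ratios between the weights matter):

```python
import torch
import torch.nn as nn

# Class counts from the question: class 1, class 2 ("normal"), class 3
counts = torch.tensor([900.0, 15000.0, 800.0])

# Inverse-frequency weights: the rarer a class, the larger its weight
weight = 1.0 / counts
weight = weight / weight.sum()  # optional normalization

criterion = nn.CrossEntropyLoss(weight=weight)

# Dummy batch: 4 samples, 3 classes
logits = torch.randn(4, 3)
targets = torch.tensor([0, 1, 2, 1])
loss = criterion(logits, targets)
```

With these weights, misclassifying a class-1 or class-3 sample contributes more to the loss than misclassifying a class-2 sample.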
Also, you could try to oversample the minority classes using WeightedRandomSampler.
I’ve created a dummy example here.
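As a rough sketch of the sampler approach (with dummy data scaled down from the question's class sizes): each sample gets a weight inversely proportional to its class count, so minority-class samples are drawn more often and batches come out roughly balanced.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

# Dummy imbalanced dataset (class sizes scaled down from the question)
labels = torch.cat([torch.zeros(90), torch.ones(1500), torch.full((80,), 2)]).long()
data = torch.randn(len(labels), 10)
dataset = TensorDataset(data, labels)

# Per-sample weight = inverse frequency of that sample's class
class_counts = torch.bincount(labels).float()
sample_weights = 1.0 / class_counts[labels]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(labels),
    replacement=True,  # required so minority samples can repeat
)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

# Each batch should now contain roughly balanced classes
x, y = next(iter(loader))
```

Note that the sampler replaces shuffling, so don't pass `shuffle=True` to the DataLoader at the same time.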

WeightedRandomSampler seems like a good idea. Thank you