I have a dataset with 3 classes of the following sizes:
- Class 1: 900 elements
- Class 2: 15000 elements
- Class 3: 800 elements
I need to predict class 1 and class 3, which signal important deviations from the norm. Class 2 is the default “normal” case which I don’t care about.
What kind of loss function should I use here?
I was thinking of using CrossEntropyLoss, but since there is a class imbalance, this would need to be weighted I suppose? How does that work in practice? Like this?
    import torch
    import torch.nn as nn
    # Class 2 count taken as 15000 to match the list above
    summed = 900 + 15000 + 800
    weight = torch.tensor([900 / summed, 15000 / summed, 800 / summed])
    crit = nn.CrossEntropyLoss(weight=weight)
Or should the weight be inverted? i.e. 1 / weight?
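To make the inverted variant concrete, here is a minimal sketch of what I mean. It uses the class counts from the list above. The normalization step (so the weights sum to 1) and the dummy batch are my own assumptions for illustration, not something I know to be required.

```python
import torch
import torch.nn as nn

# Class counts from the list above, in order: class 1, class 2, class 3.
counts = torch.tensor([900.0, 15000.0, 800.0])

# Inverse-frequency weights: the rarer a class, the larger its weight.
inv_weight = 1.0 / counts
# Assumption: rescale so the weights sum to 1 (purely for readability).
inv_weight = inv_weight / inv_weight.sum()

crit = nn.CrossEntropyLoss(weight=inv_weight)

# Dummy batch (4 samples, 3 classes) only to check that the criterion runs.
logits = torch.randn(4, 3)
targets = torch.tensor([0, 1, 2, 1])
loss = crit(logits, targets)
```

With this variant, class 3 (the smallest) gets the largest weight and class 2 the smallest, which is the opposite ordering of the frequency-proportional weights.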
Is this the right approach to begin with or are there other / better methods I could use?