I used the following formula to compute class weights, and it gave me very good results:
weight of class c = (number of occurrences of the most common class) / (number of occurrences of class c)
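As a rough sketch of the formula above, the weights can also be computed directly from the training labels. This assumes PyTorch and integer class labels in a 1-D tensor. The helper name `inverse_max_frequency_weights` is hypothetical and is not part of PyTorch. The usage below rebuilds a label tensor from the class counts in the multiclass example that follows.

```python
import torch

def inverse_max_frequency_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: weight of class c = (count of most common class) / (count of class c).

    Assumes every class occurs at least once; a class with zero samples
    would get an infinite weight.
    """
    # Count how many samples belong to each class 0..num_classes-1
    counts = torch.bincount(labels, minlength=num_classes).float()
    # Divide the largest count by each class count (broadcasts over the vector)
    return counts.max() / counts

# Rebuild a label tensor from the per-class counts of the multiclass example
counts = torch.tensor([2741, 37919, 22858, 31235, 5499])
labels = torch.repeat_interleave(torch.arange(5), counts)
weights = inverse_max_frequency_weights(labels, num_classes=5)
print(weights)
```

The most common class always gets a weight of 1, and rarer classes get proportionally larger weights.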
- In multiclass problems, suppose we have five classes with the following numbers of occurrences:

| Class | Occurrences |
|-------|-------------|
| 0     | 2741        |
| 1     | 37919       |
| 2     | 22858       |
| 3     | 31235       |
| 4     | 5499        |

The weight of each class is then 0: 37919/2741, 1: 37919/37919, 2: 37919/22858, 3: 37919/31235, 4: 37919/5499, which rounded to two decimals gives:
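```python
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = [13.83, 1.0, 1.66, 1.21, 6.9]  # most common count / count of each class
class_weights = torch.tensor(weights, dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
```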
- In binary classification problems, suppose we have two classes with the following numbers of occurrences:

| Class | Occurrences |
|-------|-------------|
| 0     | 900         |
| 1     | 100         |

You can use the `pos_weight` parameter of `nn.BCEWithLogitsLoss`, which takes the weight of the positive class (here class 1, so the weight is 900/100 = 9):
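```python
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight = [9.0]  # number of negatives / number of positives
class_weight = torch.tensor(weight, dtype=torch.float32, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weight)
```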
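Here is a minimal sketch of deriving `pos_weight` from the labels instead of hard-coding it. It assumes float 0/1 targets in which class 1 is the positive (rare) class. The label tensor mirrors the 900/100 example, and the logits are random placeholders standing in for a model's output. Per the PyTorch documentation, `pos_weight` multiplies only the loss term of positive examples.

```python
import torch
import torch.nn as nn

# Toy targets matching the example: 900 negatives (class 0), 100 positives (class 1)
targets = torch.cat([torch.zeros(900), torch.ones(100)])

num_pos = targets.sum()
num_neg = targets.numel() - num_pos
# pos_weight = negatives / positives, one entry per output logit (a single logit here)
pos_weight = (num_neg / num_pos).reshape(1)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Placeholder logits for illustration; they must have the same shape as the targets
logits = torch.randn(1000)
loss = criterion(logits, targets)
print(pos_weight.item(), loss.item())
```

With this setup, each positive example contributes to the loss as if it appeared 9 times, which balances the 900 negatives against the 100 positives.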