What do the weight values in torch.nn.CrossEntropyLoss mean?

I used this formula, and it led me to very good results:

weight of class i = (number of occurrences of the most common class) / (number of occurrences of class i)
  • In multiclass problems, suppose we have five classes with these occurrence counts:
    0: 2741
    1: 37919
    2: 22858
    3: 31235
    4: 5499
    The most common class is 1 with 37919 occurrences, so the weight of each class is 0: 37919/2741, 1: 37919/37919, 2: 37919/22858, 3: 37919/31235, 4: 37919/5499, which gives (see the sketch after this list for computing these weights directly from the counts):
weights = [13.83, 1.0, 1.66, 1.21, 6.9]
class_weights = torch.FloatTensor(weights).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
  • In binary classification problems, if we have two classes with these counts:
    0: 900
    1: 100
    you can use the pos_weight parameter of nn.BCEWithLogitsLoss, which takes the weight of the positive class (here 900/100 = 9), so (a short sketch for this case also follows after the list):
weight = [9.0]
class_weight = torch.FloatTensor(weight).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weight)
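
As mentioned in the first bullet, here is a minimal sketch that computes the multiclass weights directly from the raw class counts instead of typing them by hand; the counts tensor and the device string are my own assumptions matching the example above, not part of the original answer:

import torch
import torch.nn as nn

# class counts from the example above (index = class label)
counts = torch.tensor([2741., 37919., 22858., 31235., 5499.])

# weight of class i = count of the most common class / count of class i
class_weights = counts.max() / counts  # approx [13.83, 1.0, 1.66, 1.21, 6.9]

device = "cuda" if torch.cuda.is_available() else "cpu"
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))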
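
And a sketch of the binary case with pos_weight, again assuming the example counts (900 negatives, 100 positives) and the same device handling:

import torch
import torch.nn as nn

# example counts: 900 samples of class 0 (negative), 100 of class 1 (positive)
num_neg, num_pos = 900, 100

# pos_weight = number of negatives / number of positives
pos_weight = torch.tensor([num_neg / num_pos])  # tensor([9.])

device = "cuda" if torch.cuda.is_available() else "cpu"
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))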