What do the weight values mean in torch.nn.CrossEntropyLoss?

The official docs say weight is useful for an unbalanced training set. What do the weight values in CrossEntropyLoss actually mean?

It simply means the weight you give to each class. For classes with few training images, you give a larger weight, so the network is penalized more when it mislabels samples of those classes. For classes with many images, you give a smaller weight.
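To make "the weight you give to each class" concrete, here is a small sketch on a made-up 3-class batch (the weights are arbitrary, for illustration only). Each sample's negative log-likelihood is multiplied by the weight of its true class, and with the default reduction="mean" PyTorch divides by the sum of those weights rather than by the batch size:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 3)                      # 6 samples, 3 classes (toy data)
targets = torch.tensor([0, 1, 2, 2, 1, 0])
class_weights = torch.tensor([3.0, 1.0, 0.5])   # hypothetical per-class weights

# Built-in weighted loss (default reduction="mean")
builtin = F.cross_entropy(logits, targets, weight=class_weights)

# Same quantity by hand: scale each sample's NLL by the weight of its
# true class, then divide by the sum of those weights.
nll = -F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
w = class_weights[targets]
manual = (w * nll).sum() / w.sum()

print(builtin.item(), manual.item())
```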

So how do I decide on those values? Are they proportional to the class sizes?

If you have a roughly balanced problem, you don’t need to use weights. If you have imbalance, 1/class_size is a typical choice.
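A minimal sketch of the 1/class_size choice, computed from a label tensor with torch.bincount. The helper name inverse_frequency_weights and the toy labels are hypothetical, for illustration only:

```python
import torch

def inverse_frequency_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: return 1 / class_size for every class."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    if (counts == 0).any():
        raise ValueError("every class needs at least one sample")
    return 1.0 / counts

labels = torch.tensor([0, 0, 0, 1, 2, 2])       # toy labels: sizes 3, 1, 2
weights = inverse_frequency_weights(labels, num_classes=3)
criterion = torch.nn.CrossEntropyLoss(weight=weights)
print(weights)
```

The resulting tensor is passed directly as the weight argument; a class that never appears would get an infinite weight, which is why the sketch raises instead.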

Best regards

Thomas

What if class 0 has 9000 samples and class 1 has 41000 samples? In that case, is the weight of class 0 1/9000 and the weight of class 1 1/41000?

Not necessarily. You can try several weightings and see which works best.
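One detail that helps when experimenting: with the default reduction="mean", CrossEntropyLoss divides by the sum of the target weights, so only the ratios between class weights matter. The sketch below (random toy logits, counts taken from the question) shows that 1/9000 and 1/41000 give the same loss as 41000/9000 and 1.0, while with reduction="sum" the overall scale does carry through:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 2)                       # toy 2-class batch
targets = torch.tensor([0, 1, 1, 1, 0, 1, 1, 1])

inverse = torch.tensor([1 / 9000, 1 / 41000])    # 1 / class_size
rescaled = torch.tensor([41000 / 9000, 1.0])     # same ratios, largest class = 1

# Mean reduction normalizes by the sum of weights: identical losses.
loss_inverse = nn.CrossEntropyLoss(weight=inverse)(logits, targets)
loss_rescaled = nn.CrossEntropyLoss(weight=rescaled)(logits, targets)

# Sum reduction does not normalize: losses differ by the factor 41000.
loss_sum_inverse = nn.CrossEntropyLoss(weight=inverse, reduction="sum")(logits, targets)
loss_sum_rescaled = nn.CrossEntropyLoss(weight=rescaled, reduction="sum")(logits, targets)

print(loss_inverse.item(), loss_rescaled.item())
```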

OK, thank you for your advice.

What should the weights be to get better accuracy? On a 5-class dataset I used something like this, and the results got much worse. Can you give some intuition?

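import torch

model = torch.nn.Linear(128, 5)  # hypothetical placeholder for your own 5-class model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

weights = [1/1016, 1/12852, 1/12888, 1/3380, 1/296]  # [1 / number of instances for each class]
class_weights = torch.tensor(weights, dtype=torch.float32, device=device)

criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(parameters)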

I used this formula, and it led me to very good results:

weight of class c = (number of occurrences in the most common class) / (number of occurrences in class c)
  • In a multiclass problem with five classes and these occurrence counts:
    0 2741
    1 37919
    2 22858
    3 31235
    4 5499
    the weight value for each class is 0:(37919/2741), 1:(37919/37919), 2:(37919/22858), 3:(37919/31235), 4:(37919/5499), so:
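import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = [13.83, 1.0, 1.66, 1.21, 6.9]
class_weights = torch.tensor(weights, dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)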
  • In a binary classification problem with these two class counts:
    0 900
    1 100
    you can use the pos_weight parameter of nn.BCEWithLogitsLoss, which takes the weight of the positive class (here 900/100 = 9), so:
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pos_weight = torch.tensor([9.0], device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
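The max/count formula above can also be computed from the raw counts instead of by hand, which avoids rounding slips. A minimal sketch using the counts from this post; max_over_count_weights is a hypothetical helper name, not a PyTorch function:

```python
import torch

def max_over_count_weights(counts):
    """Hypothetical helper: weight_c = (count of most common class) / count_c."""
    counts = torch.as_tensor(counts, dtype=torch.float32)
    return counts.max() / counts

# Five-class counts from the multiclass example
multiclass = max_over_count_weights([2741, 37919, 22858, 31235, 5499])
print(multiclass)

# Binary case: pos_weight = (number of negatives) / (number of positives)
n_neg, n_pos = 900, 100
pos_weight = torch.tensor([n_neg / n_pos])
print(pos_weight)
```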

Keep in mind that with an imbalance worse than 90/10, more than 90% of the samples you present to the network produce fairly small losses (weighted by 1/n_samples). Once in a while, a sample from the other class (less than 10%) pops up with a huge loss, resulting in a relatively huge update step, which nevertheless forces you to stay at a moderate learning rate.

Now if you combine this with a sophisticated augmentation pipeline, as is often necessary, this can become a real waste of resources (electricity bill).

I think for anything bigger than tiny datasets, it can be recommended to upsample the minority classes instead. This can be achieved with a light memory footprint by using the weighted samplers provided with PyTorch (torch.utils.data.WeightedRandomSampler).
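A minimal sketch of that upsampling approach with WeightedRandomSampler, assuming a hypothetical 90/10 toy dataset. Each sample gets weight 1 / (size of its class), so both classes are drawn with roughly equal probability, and the loss itself can stay unweighted:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)
# Hypothetical imbalanced dataset: 90 samples of class 0, 10 of class 1
features = torch.randn(100, 4)
labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
dataset = TensorDataset(features, labels)

# Per-sample weight = 1 / size of that sample's class
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].double()

# Draw with replacement so minority samples can repeat within an epoch
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

drawn = torch.cat([y for _, y in loader])
print(torch.bincount(drawn, minlength=2))  # roughly balanced class counts
```

The sampler only stores one weight per index, which is why the memory cost stays small compared with physically duplicating minority samples.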

As a side note: another quite natural case where class weights come up is when we know that, in deployment, the economic cost of misclassifying class (a) is much higher than the cost of misclassifying class (b). Then we can simply put those costs (say, 2 to 1) into the CrossEntropyLoss weights: [2.0, 1.0].
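A small sketch of the cost-based weighting, assuming class (a) is index 0 and class (b) is index 1. With reduction="none", the per-sample loss is the unweighted loss multiplied by the weight of the true class, so an equally wrong prediction on class (a) costs twice as much:

```python
import math

import torch
import torch.nn as nn

# Assumed costs: misclassifying class (a)=0 is twice as expensive as class (b)=1
costs = torch.tensor([2.0, 1.0])
weighted = nn.CrossEntropyLoss(weight=costs, reduction="none")
plain = nn.CrossEntropyLoss(reduction="none")

# Two samples with identical, maximally uncertain logits but different true classes
logits = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
targets = torch.tensor([0, 1])

print(plain(logits, targets))     # both equal log(2)
print(weighted(logits, targets))  # class-0 sample counts twice as much
```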

Can this be applied to F.binary_cross_entropy (for an unbalanced binary classification problem)?
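For reference, torch.nn.functional provides F.binary_cross_entropy (which takes probabilities and an optional element-wise weight, but no pos_weight) and F.binary_cross_entropy_with_logits (which takes logits and does accept pos_weight). A sketch on random toy data, assuming hard 0/1 labels and the 900/100 counts from earlier: building a per-sample weight of 9 for positives and 1 for negatives gives the same loss as pos_weight=9:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(10)                      # toy model outputs
targets = torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 1.])
pos_weight = 9.0                              # e.g. 900 negatives / 100 positives

# Option 1: F.binary_cross_entropy needs probabilities plus one weight per element,
# so the class weight is looked up from each sample's label.
probs = torch.sigmoid(logits)
per_sample_weight = torch.where(targets == 1, torch.tensor(pos_weight), torch.tensor(1.0))
loss_bce = F.binary_cross_entropy(probs, targets, weight=per_sample_weight)

# Option 2: the logits version takes pos_weight directly and is numerically more stable.
loss_logits = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=torch.tensor([pos_weight])
)

print(loss_bce.item(), loss_logits.item())
```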