pos_weight in binary cross-entropy calculation

When dealing with imbalanced training data (many more negative samples than positive ones), the pos_weight parameter is commonly used.
The intent of pos_weight is that the model incurs a higher loss when a positive sample is misclassified than when a negative sample is misclassified.
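For reference, the PyTorch documentation for BCEWithLogitsLoss defines the per-element loss with a positive weight p, logit x_n and target y_n (shown here for a single class) as follows. With the default reduction='mean', these per-element values are averaged. Only the y_n = 1 term is scaled by p, so pos_weight raises the loss for positive targets only.

$$
\ell_n = -\Bigl[\, p\, y_n \log \sigma(x_n) + (1 - y_n) \log\bigl(1 - \sigma(x_n)\bigr) \Bigr]
$$

As a worked check with p = 5, logits (0.5, 1.5) and targets (1, 0):

$$
\ell_1 = -5 \log \sigma(0.5) \approx 2.3704, \qquad \ell_2 = -\log\bigl(1 - \sigma(1.5)\bigr) \approx 1.7014, \qquad \tfrac{1}{2}(\ell_1 + \ell_2) \approx 2.0359
$$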
When I use the binary_cross_entropy_with_logits function, I found:

import torch
import torch.nn.functional as F

pos_weight = torch.tensor([5.0])

preds_pos_wrong = torch.tensor([0.5, 1.5])
label_pos = torch.tensor([1.0, 0.0])
loss_pos_wrong = F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos, pos_weight=pos_weight)

print(loss_pos_wrong)  # tensor(2.0359)

preds_neg_wrong = torch.tensor([1.5, 0.5])
label_neg = torch.tensor([0.0, 1.0])
loss_neg_wrong = F.binary_cross_entropy_with_logits(preds_neg_wrong, label_neg, pos_weight=pos_weight)

print(loss_neg_wrong)  # tensor(2.0359)

The losses from the wrongly predicted positive and negative samples are the same, so how does pos_weight work in the loss calculation for imbalanced data?

I don’t quite understand the example, as both losses are computed from the same values (you just swapped the order of the logits and targets). Did you intend to change the logit of the negative class in the second example?
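The point above can be checked by asking for per-element losses with reduction="none": the two examples produce the same two values in swapped order, so their means are identical. In both cases the element with target 1 (logit 0.5) is the one scaled by pos_weight. A minimal sketch reusing the tensors from the question:

```python
import torch
import torch.nn.functional as F

pos_weight = torch.tensor([5.0])

# Same logits and targets as the question, but keep one loss per element
per_elem_a = F.binary_cross_entropy_with_logits(
    torch.tensor([0.5, 1.5]), torch.tensor([1.0, 0.0]),
    pos_weight=pos_weight, reduction="none")
per_elem_b = F.binary_cross_entropy_with_logits(
    torch.tensor([1.5, 0.5]), torch.tensor([0.0, 1.0]),
    pos_weight=pos_weight, reduction="none")

print(per_elem_a)  # tensor([2.3704, 1.7014])
print(per_elem_b)  # tensor([1.7014, 2.3704])
```

Averaging either tensor gives the 2.0359 reported in the question.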


Thank you for replying. I intended the positive label to be (1, 0) and the negative label to be (0, 1).
If I have 1 positive sample and 10 negative samples, how can I use pos_weight to make the loss from the positive sample 10 times greater than the loss from a negative one?
In the example, preds_pos_wrong and label_pos are the logits and label for just one sample.
For (0.5, 1.5) in preds_pos_wrong, 1.5 is the score for the positive class and 0.5 is the score for the negative class, and its label is (1, 0), meaning the probability of the positive class is 1 and the probability of the negative class is 0.
Have I understood the usage of preds_pos_wrong correctly, or am I wrong?
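With one logit and one 0/1 target per sample, pos_weight multiplies only the target-1 term, so setting it to the ratio of negatives to positives (10 here) makes a misclassified positive cost 10 times as much as an equally misclassified negative. A minimal sketch, assuming a made-up batch of 1 positive and 10 negatives; the logit magnitude 2.0 is an arbitrary illustrative choice:

```python
import torch
import torch.nn.functional as F

# Hypothetical toy batch: 1 positive (target 1) and 10 negatives (target 0)
targets = torch.tensor([1.0] + [0.0] * 10)

# Common heuristic: pos_weight = number of negatives / number of positives
num_pos = targets.sum()
num_neg = targets.numel() - num_pos
pos_weight = (num_neg / num_pos).reshape(1)  # shape (1,), value 10

# Every sample is wrong by the same margin: the positive gets a negative logit,
# each negative gets a positive logit of the same size
logits = torch.where(targets == 1, torch.tensor(-2.0), torch.tensor(2.0))

per_sample = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=pos_weight, reduction="none")

# Loss of the positive sample relative to one negative sample
ratio = per_sample[0] / per_sample[1]
print(ratio)  # close to 10
```

Because the ten negatives together contribute the same total as the single weighted positive, the two classes end up balanced in the summed loss.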

I found the solution! My mistake was using two elements to encode a binary label.
binary_cross_entropy_with_logits expects a single target per sample: 1 for a positive label and 0 for a negative label. I had written (1, 0) and (0, 1) instead.
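If labels are already stored in the two-column form described earlier, they can be collapsed to one 0/1 target per sample. A minimal sketch, assuming column 0 marks the positive class (so (1, 0) means positive); the logits are made-up values standing in for a model that outputs one logit per sample, for example a final nn.Linear(in_features, 1) squeezed to shape (N,):

```python
import torch
import torch.nn.functional as F

# Hypothetical two-column labels: (1, 0) = positive, (0, 1) = negative
onehot = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [0.0, 1.0]])

# One 0/1 target per sample: 1 for positive, 0 for negative
targets = onehot[:, 0]

# One logit per sample (illustrative values), same shape as targets
logits = torch.tensor([0.8, -1.2, 0.3])

# pos_weight here: 2 negatives / 1 positive
loss = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=torch.tensor([2.0]))
print(targets, loss)
```

The key change is the shape: logits and targets are both (N,), not (N, 2).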

Glad you solved it!

If I may add a comment:

I think balancing the data by oversampling the minority class or undersampling the majority class is even more common and, at least in my experience, can be more effective.
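One common way to oversample the minority class in PyTorch is torch.utils.data.WeightedRandomSampler. A minimal sketch on a made-up toy dataset (1 positive, 10 negatives, random features), weighting each sample by the inverse of its class count so that both classes are drawn about equally often:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical toy dataset: 1 positive and 10 negatives, 3 features each
features = torch.randn(11, 3)
labels = torch.tensor([1] + [0] * 10)

# Weight each sample by the inverse frequency of its class
class_counts = torch.bincount(labels)  # counts for class 0 and class 1
sample_weights = 1.0 / class_counts[labels].double()

# Draw with replacement so the single positive can appear many times per epoch
generator = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(sample_weights, num_samples=1000,
                                replacement=True, generator=generator)

loader = DataLoader(TensorDataset(features, labels), batch_size=100, sampler=sampler)

# Collect all drawn labels and check the fraction of positives
drawn = torch.cat([y for _, y in loader])
print(drawn.float().mean())  # fraction of positives, roughly 0.5
```

With batches balanced this way, pos_weight is typically left at its default, since reweighting the loss as well would double-count the correction.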

Best regards



Thank you for your suggestion. It sounds great!
I’m modifying my model and will try the sampling method as well.

Wu Shiauthie