Yes, from what you’ve described, pos_weight does seem suitable for your scenario. Below I’ll illustrate that it does indeed assign more weight to the positive class inside the loss calculation.
Before I do that, I just want to mention that you can also achieve a similar effect by using the vanilla loss function (without pos_weight) and simply sampling more frequently from your desired label via a WeightedRandomSampler. Personally I prefer this approach if your objective here is simply to deal with a label imbalance within the class (you could also use it to deal with class imbalance, though in your case that’s not required since you have a single class). I think the two approaches might be equivalent under some particular parameters and assumptions (someone can correct me if that’s wrong), but intuitively, if your positives are very much in the minority and you try to address that via pos_weight, I feel like your training will be very jumpy: most of your training batches won’t have any positive labels at all, and when they do, you’ll take a very large step in that direction. This feels less robust to me (and a bit less intuitive to think about) than just balancing out your sampling by drawing from the positives more frequently.
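For concreteness, here’s a minimal sketch of that sampler approach. The dataset, tensor names, and the 5% positive rate are made-up assumptions purely for illustration:

import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

# made-up data: 1000 samples with a single class and roughly 5% positive labels
features = torch.randn(1000, 16)
labels = (torch.rand(1000) < 0.05).float()
dataset = TensorDataset(features, labels)

# weight each sample inversely to its label's frequency,
# so positives and negatives get drawn about equally often
pos_frac = labels.mean()
sample_weights = torch.where(labels == 1, 1.0 / pos_frac, 1.0 / (1.0 - pos_frac))
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

# batches from this loader contain positives roughly half the time,
# so the plain BCEWithLogitsLoss (no pos_weight) sees a balanced stream of labels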
Having said that, here’s an illustration that pos_weight does indeed represent a weight multiplier on the positive label.
import torch

label = torch.Tensor([ 0., +1., 0., +1.])
prediction = torch.Tensor([-10., -10., +10., +10.])
# the labels are [ neg, pos, neg, pos]
# model's predictions are [correct, wrong, wrong, correct]
crit_basic = torch.nn.BCEWithLogitsLoss(reduction='none')
crit_posweight = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=torch.Tensor([2.]))
print("losses by element:")
print(" basic", crit_basic(prediction, label))
print("pos_weight", crit_posweight(prediction, label))
Output:
losses by element:
basic tensor([4.5418e-05, 1.0000e+01, 1.0000e+01, 4.5418e-05])
pos_weight tensor([4.5776e-05, 2.0000e+01, 1.0000e+01, 9.0835e-05])
In this setup, your model happens to be right half the time and wrong half the time. As a reminder, a model prediction of -10 means it expects a negative label, and a prediction of +10 means it expects a positive label (because sigmoid(-10) ~= 0 and sigmoid(+10) ~= +1).
You can see that your vanilla loss (the one without pos_weight) is roughly 0 when the model is right and roughly 10 when the model is wrong (10 being the magnitude of the logit). It doesn’t care whether the model was wrong about the positive or the negative label; both mistakes are weighted the same.
The loss that uses pos_weight is still roughly 10 when the label was negative (third element in the tensor), however it has doubled when the label is positive: the 2nd element in the tensor goes from roughly 10 to roughly 20 (and the 4th element, the correct positive, doubles as well). The doubling corresponds to the pos_weight passed, which is 2. Indeed, your loss function now assigns more weight to a mistake made on a positive label, but otherwise behaves the same as before.
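If you want to see exactly where that factor comes from, here’s a sketch that recomputes the same per-element losses straight from the weighted binary cross-entropy formula. (The built-in criterion is computed in a more numerically stable way, so the printed values can differ from the tensors above in the last few digits.)

import torch

label = torch.Tensor([ 0., +1., 0., +1.])
prediction = torch.Tensor([-10., -10., +10., +10.])
pos_weight = 2.0

# per-element loss: -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
p = torch.sigmoid(prediction)
manual = -(pos_weight * label * torch.log(p) + (1 - label) * torch.log(1 - p))
print(manual)  # roughly [4.5e-05, 20.0, 10.0, 9.1e-05], matching the pos_weight losses above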