# Current situation

I have a multi-label classification problem in which overconfidence is a problem for the end application. Each sample is labeled with one or more of `[A, B, C, D, E]`, but in reality a label such as B should not be treated as a hard 1 or 0 but rather as e.g. 0.7 (which is unfortunately unattainable).
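As a sidenote: `BCEWithLogitsLoss` does accept soft targets in `[0, 1]`, so if such values were attainable, the setup could look like this (a minimal sketch with made-up numbers):

```python
import torch

loss_func = torch.nn.BCEWithLogitsLoss()
# One sample; hypothetical soft target: label B is 0.7 instead of a hard 1.
pred = torch.tensor([[0.1, 0.8, 0.2, 0.1, 0.1]])
soft_target = torch.tensor([[0.0, 0.7, 0.0, 0.0, 0.0]])
loss = loss_func(pred, soft_target)
```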

## Normal training

If I use `BCEWithLogitsLoss` as normal on data like this:

```
import torch

loss_func = torch.nn.BCEWithLogitsLoss()
pred = torch.tensor([[0.1, 0.1, 0.7, 0.2, 0.7],
                     [0.8, 0.5, 0.1, 0.1, 0.2]])
target = torch.tensor([[0, 0, 1, 0, 1],
                       [1, 1, 0, 0, 0]])
loss_func(pred, target.type(torch.FloatTensor))
# tensor(0.6225)
```

I can successfully train a model. The problem is that the resulting confidence values look like `[0.0, 0.0, 0.99, 0.0, 0.98]`.
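For reference, these confidences come from applying a sigmoid to the raw model outputs; `BCEWithLogitsLoss` itself expects logits and applies the sigmoid internally:

```python
import torch

# Logits from the first row of the example above.
logits = torch.tensor([0.1, 0.1, 0.7, 0.2, 0.7])
conf = torch.sigmoid(logits)
# tensor([0.5250, 0.5250, 0.6682, 0.5498, 0.6682])
```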

# Goal

I want to tell the loss function: “If the confidence values of correct labels are >= 0.6 and those of wrong labels are < 0.6, don’t calculate a loss”.

## Attempt

Set the rows that are entirely correct to binary format, so that the loss for those rows is (close to) 0.

```
pred_binary = torch.where(pred >= 0.6, torch.tensor(1), torch.tensor(0))
compare = torch.where(pred_binary == target, torch.tensor(1), torch.tensor(0))
# tensor([[1, 1, 1, 1, 1],
#         [1, 0, 1, 1, 1]])
compare_row = compare.type(torch.FloatTensor).mean(axis=1)
# tensor([1.0000, 0.8000])
select_row = torch.where(compare_row >= 1, torch.tensor([1]), torch.tensor([0]))
# tensor([1, 0])
select_row = select_row.type(torch.bool)
# tensor([True, False])
pred[select_row, :] = pred_binary[select_row].type(torch.FloatTensor)
# tensor([[0.0000, 0.0000, 1.0000, 0.0000, 1.0000],
#         [0.8000, 0.5000, 0.1000, 0.1000, 0.2000]])
loss_func(pred, target.type(torch.FloatTensor))
# tensor(0.5838)
```
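A more compact variant of the same idea that I have been considering (a sketch, not verified in actual training): compute per-element losses with `reduction='none'` and zero out the rows that already satisfy the threshold. This makes the skipped rows contribute exactly zero, whereas the binary substitution above still produces a small loss on them:

```python
import torch

pred = torch.tensor([[0.1, 0.1, 0.7, 0.2, 0.7],
                     [0.8, 0.5, 0.1, 0.1, 0.2]])
target = torch.tensor([[0, 0, 1, 0, 1],
                       [1, 1, 0, 0, 0]])

loss_none = torch.nn.BCEWithLogitsLoss(reduction='none')
# Per-element correctness at the 0.6 threshold.
correct = (pred >= 0.6) == target.bool()
# Rows where every label is already on the right side of the threshold.
row_ok = correct.all(dim=1, keepdim=True)
losses = loss_none(pred, target.float()).masked_fill(row_ok, 0.0)
loss = losses.mean()
```

Note that the zeros of the masked rows are still included in the mean, so this is not identical to simply dropping those rows from the batch.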

The loss is indeed lower (`0.5838` versus `0.6225`).

# Questions

- Is there a smarter way to do what I want?
- Or more efficient code?
- Any feedback on what I should watch out for when doing this?