Small addition to @MariosOreo's answer: if your loss function uses `reduction='mean'` (the default), the loss is normalized by the sum of the weights corresponding to each target element. If you use `reduction='none'`, you have to take care of that normalization yourself.
Here is a small example:
```python
import torch
import torch.nn as nn

x = torch.randn(10, 5)               # logits for 10 samples, 5 classes
target = torch.randint(0, 5, (10,))  # class index per sample
weights = torch.tensor([1., 2., 3., 4., 5.])

# reduction='mean' (default): normalized by the sum of the target weights
criterion_weighted = nn.CrossEntropyLoss(weight=weights)
loss_weighted = criterion_weighted(x, target)

# reduction='none': returns per-sample weighted losses, so normalize manually
criterion_weighted_manual = nn.CrossEntropyLoss(weight=weights, reduction='none')
loss_weighted_manual = criterion_weighted_manual(x, target)
loss_weighted_manual = loss_weighted_manual.sum() / weights[target].sum()

# use allclose instead of == to allow for floating point differences
print(torch.allclose(loss_weighted, loss_weighted_manual))  # True
```
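For completeness, the same normalization can also be verified from scratch via the log-softmax, without relying on `reduction='none'` at all. This is just a sketch of what `nn.CrossEntropyLoss` computes internally under the weighted-mean reduction:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(10, 5)
target = torch.randint(0, 5, (10,))
weights = torch.tensor([1., 2., 3., 4., 5.])

# per-sample negative log-likelihood, scaled by the weight of each target class
log_probs = F.log_softmax(x, dim=1)
per_sample = -log_probs[torch.arange(x.size(0)), target] * weights[target]

# weighted mean: divide by the sum of the applied weights, not by the batch size
manual = per_sample.sum() / weights[target].sum()

criterion = nn.CrossEntropyLoss(weight=weights)  # reduction='mean'
print(torch.allclose(manual, criterion(x, target)))  # True
```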