Yes, since it might not be trivial to apply weights in a multi-label classification use case.
Let me give you an example.
In a multi-class classification use case you can directly apply a class weight to the corresponding samples, as seen here:
import torch
import torch.nn as nn

# multi-class classification
batch_size = 10
nb_classes = 4
logits = torch.randn(batch_size, nb_classes, requires_grad=True)
targets = torch.randint(0, nb_classes, (batch_size,))
weights = torch.rand(nb_classes)
print(targets)
# tensor([2, 3, 3, 0, 1, 2, 3, 2, 0, 1])
print(weights)
# tensor([0.9253, 0.1432, 0.8336, 0.9465])
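# built-in approach: pass the per-class weights directly to the criterion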
weighted_criterion = nn.CrossEntropyLoss(weight=weights, reduction="mean")
loss = weighted_criterion(logits, targets)
print(loss)
# tensor(2.6470, grad_fn=<NllLossBackward0>)
raw_criterion = nn.CrossEntropyLoss(reduction="none")
loss_raw = raw_criterion(logits, targets)
print(loss_raw)
# tensor([2.3437, 3.4518, 3.3348, 2.1393, 0.9009, 5.2935, 1.5514, 0.5305, 2.6735,
# 3.5571], grad_fn=<NllLossBackward0>)
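# manual approach: weight each sample's loss by its target's class weight
# and normalize by the sum of the applied weights (matches reduction="mean")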
loss_weighted = (loss_raw * weights[targets] / weights[targets].sum()).sum()
print(loss_weighted)
# tensor(2.6470, grad_fn=<SumBackward0>)
Indexing the weights tensor with the targets works fine and returns the expected loss, as verified by the manual comparison above.
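For reference, this matches what nn.CrossEntropyLoss computes internally for reduction="mean" with a weight argument: sum(weight[target_i] * loss_i) / sum(weight[target_i]), i.e. a weighted mean normalized by the sum of the applied weights.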
However, in a multi-label classification use case each sample can belong to zero, one, or multiple classes as seen here:
# multi-label classification (logits, weights, batch_size, and nb_classes are reused from above)
targets = torch.randint(0, 2, (batch_size, nb_classes))
print(targets)
# tensor([[1, 0, 1, 1],
# [1, 1, 1, 1],
# [0, 0, 1, 1],
# [1, 1, 1, 1],
# [0, 1, 1, 0],
# [1, 0, 0, 0],
# [0, 0, 0, 1],
# [1, 0, 0, 1],
# [1, 0, 0, 1],
# [1, 0, 0, 0]])
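# unreduced loss: one binary cross-entropy term per sample and per class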
raw_criterion = nn.BCEWithLogitsLoss(reduction="none")
loss_raw = raw_criterion(logits, targets.float())
print(loss_raw)
# tensor([[0.7298, 0.3813, 1.4581, 0.5213],
# [1.4387, 0.5741, 0.6803, 2.5363],
# [0.3412, 0.3628, 0.1299, 1.4728],
# [0.9664, 1.2701, 0.2159, 2.8560],
# [0.6312, 0.4537, 0.6042, 0.3793],
# [0.2912, 2.0887, 0.0754, 1.8689],
# [1.0385, 2.0379, 0.8648, 0.3196],
# [0.7231, 0.1918, 1.3323, 0.8123],
# [0.8825, 0.9056, 1.9396, 0.3908],
# [0.7225, 0.3262, 0.3110, 2.5515]],
# grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
print(weights)
# tensor([0.9253, 0.1432, 0.8336, 0.9465])
It’s now unclear to me how you would like to apply the weights.
E.g. take the first sample with a target of [1, 0, 1, 1], which means classes 0, 2, and 3 are “active”.
Would you sum the corresponding weights and multiply the result directly with the unreduced loss?
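For illustration only, here is a minimal sketch of a few interpretations I could imagine (the option names are made up, and the reduction details are chosen arbitrarily):

# Option A: sum the weights of the active classes per sample and use the
# result as a per-sample weight for the sample's mean unreduced loss
sample_weights = (targets.float() * weights).sum(dim=1)  # shape: [batch_size]
loss_a = (loss_raw.mean(dim=1) * sample_weights / sample_weights.sum()).sum()

# Option B: broadcast the per-class weights over the unreduced loss and
# compute a weighted mean over all sample/class terms
loss_b = (loss_raw * weights).sum() / weights.expand_as(loss_raw).sum()

# Option C: use the built-in pos_weight argument, which only scales the
# loss contribution of the positive targets per class
pos_criterion = nn.BCEWithLogitsLoss(pos_weight=weights)
loss_c = pos_criterion(logits, targets.float())

Depending on whether you want to weight whole samples, each class term, or only the positive targets, the right formulation would differ.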