Yeah, the documentation might be a bit misleading: `size_average` is mentioned, but this argument is deprecated and `reduction` should be used instead.
There are basically three reduction types, i.e. `'none'`, `'sum'`, and `'mean'`.
`ignore_index` will be applied as follows:

- `reduction='none'`: the loss will not be reduced, so you will get a loss tensor of shape `[batch_size]`. The entries where `target == ignore_index` will have a zero loss.
- `reduction='mean'`: the reduced loss will be the average of all entries where `target != ignore_index`.
- `reduction='sum'`: the reduced loss will be the sum of the "raw" loss. Since the samples with ignored targets get a zero loss, the sum does not change whether you filter them out or just sum over all values.
Here is a small code snippet to demonstrate my understanding:
```python
import torch
import torch.nn.functional as F

# Setup
output = torch.randn(10, 10, requires_grad=True)
target = torch.arange(10)

# Sanity check for the plain loss without ignore_index
loss_raw = F.cross_entropy(output, target, reduction='none')
loss_mean = F.cross_entropy(output, target, reduction='mean')
print(loss_raw.mean() == loss_mean)

loss_sum = F.cross_entropy(output, target, reduction='sum')
print(loss_raw.sum() == loss_sum)

# Case 2: ignore_index=0, reduction='mean'
loss_raw_ignore = F.cross_entropy(
    output, target, reduction='none', ignore_index=0)
loss_mean_ignore = F.cross_entropy(
    output, target, reduction='mean', ignore_index=0)
print(loss_mean_ignore == loss_raw_ignore[loss_raw_ignore != 0].mean())

# Check gradients
output.grad = None
loss_mean_ignore.backward()
g0 = output.grad.clone()

output.grad = None
loss_raw_ignore[loss_raw_ignore != 0].mean().backward(retain_graph=True)
g1 = output.grad.clone()
print((g0 == g1).all())

# Case 3: ignore_index=0, reduction='sum'
loss_sum_ignore = F.cross_entropy(
    output, target, reduction='sum', ignore_index=0)
print(loss_sum_ignore == loss_raw_ignore.sum())

# Check gradients
output.grad = None
loss_sum_ignore.backward()
g0 = output.grad.clone()

output.grad = None
loss_raw_ignore.sum().backward(retain_graph=True)
g1 = output.grad.clone()

output.grad = None
loss_raw_ignore[loss_raw_ignore != 0].sum().backward()
g2 = output.grad.clone()
print((g0 == g1).all() and (g0 == g2).all())
```
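One detail worth making explicit: with `reduction='mean'` and `ignore_index` set, the denominator is the number of non-ignored samples, not the batch size. A minimal sketch checking this directly (same toy shapes as above, no gradients needed):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
output = torch.randn(10, 10)
target = torch.arange(10)

# Per-sample losses; entry 0 is zeroed out by ignore_index=0
loss_none = F.cross_entropy(output, target, reduction='none', ignore_index=0)
loss_mean = F.cross_entropy(output, target, reduction='mean', ignore_index=0)

# Manual mean: divide by the number of valid samples (9), not the batch size (10)
valid = target != 0
manual_mean = loss_none[valid].sum() / valid.sum()
print(torch.allclose(loss_mean, manual_mean))  # True
```

If PyTorch divided by the full batch size instead, `loss_none.sum() / 10` would match, which you can verify is not the case here.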
Let me know if I'm missing something.