UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead

Yeah, the documentation might be a bit misleading, as size_average is still mentioned, although this argument is deprecated and reduction should be used instead.
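
If it helps, this is roughly how the legacy arguments map onto reduction (a minimal sketch, assuming weight and ignore_index are left at their defaults; double check against your PyTorch version):

import torch.nn as nn

# deprecated args                      new equivalent
# size_average=True,  reduce=True  ->  reduction='mean' (the default)
# size_average=False, reduce=True  ->  reduction='sum'
# reduce=False                     ->  reduction='none'
criterion_old = nn.CrossEntropyLoss(size_average=False)  # raises the UserWarning
criterion_new = nn.CrossEntropyLoss(reduction='sum')      # equivalent call without the warning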

There are basically three reduction modes, i.e. 'none', 'sum', and 'mean'.
ignore_index will be applied as follows:

  1. reduction='none': the loss will not be reduced, so you will get a loss tensor of shape [batch_size]. The entries where target==ignore_index will have a zero loss.
  2. reduction='mean': the reduced loss will be the average over all entries where target!=ignore_index, i.e. the sum of the valid losses divided by the number of non-ignored targets.
  3. reduction='sum': the reduced loss will be the sum of the “raw” loss. Since the samples with ignored targets get a zero loss, the sum will not change whether you filter them out first or just sum over all values.

Here is a small code snippet to demonstrate my understanding:

import torch
import torch.nn.functional as F

# Setup
output = torch.randn(10, 10, requires_grad=True)
target = torch.arange(10)

# Case 1: sanity check for the plain loss without ignore_index
loss_raw = F.cross_entropy(output, target, reduction='none')

loss_mean = F.cross_entropy(output, target, reduction='mean')
print(loss_raw.mean() == loss_mean)

loss_sum = F.cross_entropy(output, target, reduction='sum')
print(loss_raw.sum() == loss_sum)

# Case 2: ignore_index=0, reduction='mean'
loss_raw_ignore = F.cross_entropy(
    output, target, reduction='none', ignore_index=0)

loss_mean_ignore = F.cross_entropy(
    output, target, reduction='mean', ignore_index=0)
print(loss_mean_ignore == loss_raw_ignore[loss_raw_ignore!=0].mean())

# Check gradients
output.grad = None
loss_mean_ignore.backward()
g0 = output.grad.clone()

output.grad = None
loss_raw_ignore[loss_raw_ignore!=0].mean().backward(retain_graph=True)
g1 = output.grad.clone()
print((g0 == g1).all())

# Case 3: ignore_index=0, reduction='sum'
loss_sum_ignore = F.cross_entropy(
    output, target, reduction='sum', ignore_index=0)
print(loss_sum_ignore == loss_raw_ignore.sum())

# Check gradients
output.grad = None
loss_sum_ignore.backward()
g0 = output.grad.clone()

output.grad = None
loss_raw_ignore.sum().backward(retain_graph=True)
g1 = output.grad.clone()

output.grad = None
loss_raw_ignore[loss_raw_ignore!=0].sum().backward()
g2 = output.grad.clone()
print((g0 == g1).all() and (g0 == g2).all())
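
As a side note to the snippet: filtering via loss_raw_ignore!=0 would in principle also drop a valid sample whose loss happens to be exactly zero, so masking on the targets is the safer check. Continuing the same setup (ignore_index=0 as before), the manual mean should match as well:

# alternative check: mask the targets instead of the zero losses
valid_mask = target != 0  # True for all samples which are not ignored
loss_mean_manual = loss_raw_ignore[valid_mask].sum() / valid_mask.sum()
print(torch.allclose(loss_mean_ignore, loss_mean_manual))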

Let me know if I’m missing something.
