Does the `reduction` argument of `cross_entropy` work?

Maybe I’m stupid but I don’t think it works.

Consider this fragment:

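import torch
import torch.nn.functional as F

probs = torch.tensor([0.8, 0.1, 0.05, 0.05])
target = torch.tensor([0.9, 0.0, 0.1, 0.0])
assert probs.shape == target.shape
# compare with a tolerance: a float32 sum is not guaranteed to be exactly 1.0
assert torch.isclose(probs.sum(), torch.tensor(1.0))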

Then we compute the binary cross entropy with reduction='none' and try to replicate the results manually:

print(F.binary_cross_entropy(input=probs, target=target, reduction='none'))
bce_loss = -((target * torch.log(probs)) + ((1-target) * torch.log(1-probs)))
print(bce_loss)

tensor([0.3618, 0.1054, 0.3457, 0.0513])
tensor([0.3618, 0.1054, 0.3457, 0.0513])

Then we try reduction='mean' and reduction='sum':

print(F.binary_cross_entropy(input=probs, target=target, reduction='mean'))
print(bce_loss.mean())

tensor(0.2160)
tensor(0.2160)

print(F.binary_cross_entropy(input=probs, target=target, reduction='sum'))
print(bce_loss.sum())

tensor(0.8642)
tensor(0.8642)

Instead, cross_entropy seems to ignore the reduction argument altogether and always sums?

logits = torch.tensor([10.0, 20.0, -5.0, 6.0])
target = torch.tensor([0.9, 0.0, 0.1, 0.0])
print(F.cross_entropy(input=logits, target=target, reduction='none'))
ce_loss = -(target * torch.log(torch.softmax(logits, 0)))
print(ce_loss)
print(ce_loss.sum())

tensor(11.5000)
tensor([9.0000, 0.0000, 2.5000, 0.0000])
tensor(11.5000)

Why does it do so? Is there a bug?

Hi Phillip!

reduction for cross_entropy() does, in fact, work.

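What’s going on is that reduction reduces across the batch dimension. (The same
is true for binary_cross_entropy(), but there every element is its own binary
prediction, so your four-element tensors act as a batch of four.)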

The likely source of confusion here is that your tensors have no batch dimension. (In
this case, cross_entropy() performs the computation for a single logits / target
pair, essentially the same as if your tensors had a batch dimension of one.) There is no
(non-trivial) batch to reduce across.
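
As a quick check of the "batch dimension of one" point, here is a minimal sketch that reuses the logits and target from the question. It assumes a PyTorch version recent enough to accept both class-probability targets and unbatched inputs in cross_entropy(); the names unbatched and batched are only for illustration.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([10.0, 20.0, -5.0, 6.0])
target = torch.tensor([0.9, 0.0, 0.1, 0.0])

# Unbatched call: a single sample, so the per-sample loss is one scalar.
unbatched = F.cross_entropy(logits, target, reduction='none')

# Same data with an explicit batch dimension of one (shapes become [1, 4]).
batched = F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0), reduction='none')

print(unbatched.shape, batched.shape)          # scalar vs. one-element vector
print(torch.allclose(unbatched, batched[0]))   # same loss value either way
```

With only one sample, 'none', 'mean', and 'sum' all give the same number, which is why the argument looks ignored.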

Try your test with a non-trivial batch dimension, say, for example, with logits and
target both having shape [3, 4], that is, with a batch dimension of 3 (and a class
dimension of 4).
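
Something along these lines should show the difference. It is only a sketch: the random logits and targets are made up for illustration, and it assumes a PyTorch version in which cross_entropy() accepts class-probability targets.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 4)                         # batch of 3, 4 classes
target = torch.softmax(torch.randn(3, 4), dim=1)   # each row is a probability vector

loss_none = F.cross_entropy(logits, target, reduction='none')
loss_mean = F.cross_entropy(logits, target, reduction='mean')
loss_sum = F.cross_entropy(logits, target, reduction='sum')

# Manual version: sum over the class dimension, leaving one loss per sample.
manual = -(target * torch.log_softmax(logits, dim=1)).sum(dim=1)

print(loss_none.shape)                             # one loss per batch element
print(torch.allclose(loss_none, manual))
print(torch.allclose(loss_mean, manual.mean()))    # mean over the batch
print(torch.allclose(loss_sum, manual.sum()))      # sum over the batch
```

Note that the sum over the class dimension always happens; reduction only controls what is done with the resulting per-sample losses.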

Best.

K. Frank