Cross Entropy Loss Math under the hood

Note that you are not using nn.CrossEntropyLoss correctly, as this criterion expects logits and will apply F.log_softmax internally, while probs already contains probabilities, as @KFrank explained.
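As a minimal sketch of this point, the snippet below compares nn.CrossEntropyLoss on logits with nn.NLLLoss on F.log_softmax outputs, and then shows what happens when probabilities are passed instead. The tensor shapes ([N, C, H, W] logits, [N, H, W] targets), the seed, and the variable names are assumptions for illustration, not your actual data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)          # hypothetical logits, shape [N, C, H, W]
target = torch.randint(0, 3, (2, 4, 4))   # class indices, shape [N, H, W]

# CrossEntropyLoss on logits equals NLLLoss on log_softmax(logits)
ce_from_logits = nn.CrossEntropyLoss()(logits, target)
nll_from_log_probs = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

# Passing probabilities means log_softmax is applied to values that were
# already softmaxed, so the result no longer matches the intended loss
probs = torch.softmax(logits, dim=1)
ce_from_probs = nn.CrossEntropyLoss()(probs, target)

print(ce_from_logits, nll_from_log_probs, ce_from_probs)
```

The first two values should agree, while the third is computed from a second softmax over values in [0, 1] and is therefore a different quantity.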

So, let’s change the criterion to nn.NLLLoss and apply torch.log manually.
This approach is only meant to demonstrate the formula and shouldn’t be used in practice, as torch.log(torch.softmax(...)) is less numerically stable than F.log_softmax.
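To illustrate the stability difference, here is a small sketch with a deliberately extreme pair of logits (the values are made up for the demonstration). In float32 the softmax of the smaller logit underflows to zero, so taking the log afterwards gives -inf, while F.log_softmax stays finite.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits with a large gap between the two classes (float32)
logits = torch.tensor([[0.0, 200.0]])

naive = torch.log(torch.softmax(logits, dim=1))  # softmax underflows to 0 first
stable = F.log_softmax(logits, dim=1)            # computed in log space

print(naive)   # first entry is -inf
print(stable)  # first entry is finite, about -200
```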

Also, the default reduction for the criteria in PyTorch will calculate the average over the observations, so let’s use reduction='sum'.
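A quick sketch of how the two reductions relate, using made-up tensors (shapes and seed are assumptions): without class weights or ignored targets, the default 'mean' result equals the 'sum' result divided by the number of target elements.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = F.log_softmax(torch.randn(2, 3, 4, 4), dim=1)  # hypothetical [N, C, H, W]
target = torch.randint(0, 3, (2, 4, 4))                    # hypothetical [N, H, W]

loss_mean = nn.NLLLoss()(log_probs, target)                 # default reduction='mean'
loss_sum = nn.NLLLoss(reduction='sum')(log_probs, target)

# No class weights and no ignored targets here, so mean = sum / number of targets
print(loss_mean, loss_sum / target.numel())
```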

With these changes, the loss is computed as:

criterion = nn.NLLLoss(reduction='sum')
loss = criterion(torch.log(probs1), target)

The manual approach from your formula would correspond to:

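# Manual approach using your formula
one_hot = F.one_hot(target, num_classes=3)  # appends the class dim: [N, H, W, C]
one_hot = one_hot.permute(0, 3, 1, 2)       # move classes to dim 1: [N, C, H, W]
# the mask keeps only the target-class terms; 1e-7 guards against log(0)
ce = (one_hot * torch.log(probs1 + 1e-7))[one_hot.bool()]
ce = -1 * ce.sum()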

While the manual approach from the PyTorch docs would give you:

# Using the formula from the docs
loss_manual = -1 * torch.log(probs1).gather(1, target.unsqueeze(1))
loss_manual = loss_manual.sum()

We should get the same results:

print(loss, ce, loss_manual)
> tensor(2.8824) tensor(2.8824) tensor(2.8824)

which looks correct.
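If you want to run the comparison end to end, here is a self-contained sketch. Since the original probs1 and target are not shown here, the tensors below are random stand-ins (shapes and seed are assumptions), so the printed value will differ from 2.8824. The small 1e-7 offset in the one-hot variant also means it only matches the other two approximately.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-ins for the original tensors: probabilities [N, C, H, W], targets [N, H, W]
probs1 = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
target = torch.randint(0, 3, (2, 4, 4))

# Built-in criterion on manually computed log-probabilities
loss = nn.NLLLoss(reduction='sum')(torch.log(probs1), target)

# One-hot formulation
one_hot = F.one_hot(target, num_classes=3).permute(0, 3, 1, 2)
ce = -1 * (one_hot * torch.log(probs1 + 1e-7))[one_hot.bool()].sum()

# Gather formulation from the docs
loss_manual = -1 * torch.log(probs1).gather(1, target.unsqueeze(1)).sum()

print(loss, ce, loss_manual)
```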
