Could someone please explain the math under the hood of Cross Entropy Loss in PyTorch?
I was running some tests, and the result of Cross Entropy Loss in PyTorch doesn’t match the result from the expression below:
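Assuming the expression is the conventional cross-entropy (the form that the one-hot computation later in this thread implements), it can be written as follows. The symbols are chosen here for illustration: $y_{i,c}$ is the one-hot target, $p_{i,c}$ the predicted probability of class $c$ at position $i$, and $t_i$ the target class at position $i$.

$$
\mathrm{CE} = -\sum_{i}\sum_{c=1}^{C} y_{i,c}\,\log p_{i,c} = -\sum_{i} \log p_{i,t_i}
$$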
I took some examples computed with the expression above, ran them through PyTorch’s Cross Entropy Loss, and the results were not the same.
This is the example I am trying with PyTorch’s Cross Entropy Loss:
The values of each pixel along the 3 channels form a probability distribution, so there is one distribution for each spatial position of the tensor, and the target holds the class index for each distribution.
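To make this layout concrete, here is a small sketch with hypothetical shapes (the original tensors aren’t shown): probs1 is (N, 3, H, W) with a softmax over dim=1, and target is (N, H, W) holding one class index per position.

```python
import torch

torch.manual_seed(0)
# Hypothetical shapes: batch N=1, C=3 classes, a 2x2 spatial grid
scores = torch.randn(1, 3, 2, 2)
# Softmax over the channel dimension (dim=1) turns the 3 values at each
# pixel into a probability distribution
probs1 = torch.softmax(scores, dim=1)
# One class index (0, 1 or 2) per pixel, shape (N, H, W)
target = torch.randint(0, 3, (1, 2, 2))

# Summing over the channels gives (numerically) 1 at every position
print(probs1.sum(dim=1))
```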
How can I tell whether this loss is being computed correctly?
The issue is that PyTorch’s CrossEntropyLoss doesn’t exactly match the conventional definition of cross-entropy that you gave above.
Rather, it expects raw-score logits as its inputs and, in effect, applies softmax() to the logits internally to convert them to probabilities.
(CrossEntropyLoss might better have been named CrossEntropyWithLogitsLoss.)
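A minimal sketch of what "applies softmax internally" means, using hypothetical logits and targets: CrossEntropyLoss on raw scores gives the same value as F.log_softmax followed by nn.NLLLoss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical raw scores (logits) for 4 samples and 3 classes
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])

ce = nn.CrossEntropyLoss()(logits, target)
# The same computation spelled out: log_softmax over the class dim, then NLLLoss
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

print(ce.item(), nll.item())  # the two agree up to floating-point rounding
```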
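To check this, you could apply log() to convert your probs1 tensor into logits, and then run that through CrossEntropyLoss.
(Plain log() works because softmax(log(p)) = p when p already sums to one along the class dimension; the binary logit function, log (p / (1 - p)), does not invert softmax when there are three classes.)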
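A minimal sketch of this check, assuming torch.log as the probabilities-to-logits conversion and small hypothetical tensors (probs, target) in place of the original ones:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical probabilities for 4 samples and 3 classes (each row sums to 1)
probs = torch.softmax(torch.randn(4, 3), dim=1)
target = torch.tensor([0, 2, 1, 2])

# log(p) is a valid set of logits for p: softmax(log p) = p / sum(p) = p
logits = torch.log(probs)
assert torch.allclose(torch.softmax(logits, dim=1), probs)

ce = nn.CrossEntropyLoss(reduction='sum')(logits, target)
# Conventional cross-entropy with one-hot targets: -sum of log p[target]
manual = -torch.log(probs[torch.arange(4), target]).sum()
print(ce.item(), manual.item())
```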
Note that you are not using nn.CrossEntropyLoss correctly, as this criterion expects logits and will apply F.log_softmax internally, while probs1 already contains probabilities, as @KFrank explained.
So, let’s change the criterion to nn.NLLLoss and apply torch.log manually.
This approach is only meant to demonstrate the formula and shouldn’t be used in practice, as torch.log(torch.softmax(...)) is less numerically stable than F.log_softmax.
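To illustrate the stability point, a small sketch with a hypothetical pair of logits far apart: in float32, exp(-200) underflows to zero, so taking the log of the softmax output yields -inf, whereas F.log_softmax stays finite.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits with a large gap between the two classes
x = torch.tensor([0.0, 200.0])

# exp(-200) underflows to 0 in float32, so the first entry becomes -inf
naive = torch.log(torch.softmax(x, dim=0))
# log_softmax works in log space, so the first entry stays at -200
stable = F.log_softmax(x, dim=0)

print(naive)
print(stable)
```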
Also, the criteria in PyTorch default to reduction='mean', which averages over the observations, so let’s use reduction='sum' to match the summed formula.
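A quick sketch of how the two reductions relate, on hypothetical (N, C, H, W) tensors: without class weights, the default 'mean' divides the summed loss by the number of target elements (every pixel), not just by the batch size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical segmentation-style tensors: N=2, C=3, H=W=4
log_probs = torch.log_softmax(torch.randn(2, 3, 4, 4), dim=1)
target = torch.randint(0, 3, (2, 4, 4))

loss_mean = nn.NLLLoss()(log_probs, target)               # default reduction='mean'
loss_sum = nn.NLLLoss(reduction='sum')(log_probs, target)

# Without class weights, 'mean' is the sum divided by N*H*W = 32 pixels
print(loss_mean.item(), (loss_sum / target.numel()).item())
```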
With these changes, you get:
criterion = nn.NLLLoss(reduction='sum')
loss = criterion(torch.log(probs1), target)
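A self-contained version of this snippet, with hypothetical random stand-ins for probs1 and target (the original tensors are not shown); reduction='none' is added to expose the per-position values that 'sum' adds up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for probs1 and target: N=1, C=3, H=W=2
probs1 = torch.softmax(torch.randn(1, 3, 2, 2), dim=1)
target = torch.randint(0, 3, (1, 2, 2))

# reduction='none' keeps one loss value per pixel, shape (N, H, W)
per_pixel = nn.NLLLoss(reduction='none')(torch.log(probs1), target)
print(per_pixel.shape)

# Each entry is -log of the probability assigned to that pixel's target class
i, j = 1, 0
c = target[0, i, j].item()  # target class at pixel (i, j)
print(per_pixel[0, i, j].item(), -torch.log(probs1[0, c, i, j]).item())

# Summing the per-pixel map reproduces reduction='sum'
loss = nn.NLLLoss(reduction='sum')(torch.log(probs1), target)
print(torch.allclose(per_pixel.sum(), loss))
```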
The manual approach from your formula would correspond to:
# Manual approach using your formula
one_hot = F.one_hot(target, num_classes=3)
one_hot = one_hot.permute(0, 3, 1, 2)
ce = (one_hot * torch.log(probs1 + 1e-7))[one_hot.bool()]
ce = -1 * ce.sum()
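A runnable version of this manual computation, with hypothetical random stand-ins for probs1 and target, compared against the NLLLoss result; the 1e-7 offset means the two agree closely but not bit-for-bit.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical stand-ins for probs1 and target: N=2, C=3, H=W=4
probs1 = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
target = torch.randint(0, 3, (2, 4, 4))

# F.one_hot returns (N, H, W, C); move the class axis to dim 1 to match probs1
one_hot = F.one_hot(target, num_classes=3).permute(0, 3, 1, 2)
# Keep only the target-class terms of y_c * log(p_c), then negate the sum
ce = (one_hot * torch.log(probs1 + 1e-7))[one_hot.bool()]
ce = -ce.sum()

loss = nn.NLLLoss(reduction='sum')(torch.log(probs1), target)
# The 1e-7 guard against log(0) makes the two values differ very slightly
print(ce.item(), loss.item())
```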
The manual approach from the PyTorch docs, in turn, would give you:
# Using the formula from the docs
loss_manual = -1 * torch.log(probs1).gather(1, target.unsqueeze(1))
loss_manual = loss_manual.sum()
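The gather-based version can also be checked against NLLLoss; this sketch uses hypothetical random stand-ins for probs1 and target. gather(1, ...) selects, at every position, the log-probability of that position’s target class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for probs1 and target: N=2, C=3, H=W=4
probs1 = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
target = torch.randint(0, 3, (2, 4, 4))

# target.unsqueeze(1) has shape (N, 1, H, W), so gather along dim 1
# returns log p[target] at every (n, h, w) with that same shape
picked = torch.log(probs1).gather(1, target.unsqueeze(1))
loss_manual = -picked.sum()

loss = nn.NLLLoss(reduction='sum')(torch.log(probs1), target)
print(loss_manual.item(), loss.item())  # equal up to floating-point rounding
```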
I was testing with lower-dimensional tensors to make sure the loss result is correct, so that I could then move to higher-dimensional tensors without worrying that the loss value might be wrong.
I was actually using nn.CrossEntropyLoss() the wrong way; I apologize for that.
Now I understand how to use it!