Different loss values with cross-entropy loss: class-index targets vs. probability targets

I was going through the documentation of CrossEntropyLoss and noticed that in the "Example of target with class probabilities" the target is produced by applying softmax. But softmax is already performed internally by cross-entropy loss, so I'm confused about why it is applied again in that example:

>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()

Moreover, in the following example, shouldn't the loss be the same in the first two cases?

import torch
import torch.nn as nn

# integer class-index targets
target_tensor = torch.tensor([1, 1, 0, 1])
# probability targets intended to match target_tensor
target_tensor_prob = torch.tensor([[1.0, 0.0],
                                   [1.0, 0.0],
                                   [0.0, 1.0],
                                   [1.0, 0.0]])
input_tensor = torch.randn(4, 2)

loss = nn.CrossEntropyLoss(reduction='mean')
with torch.no_grad():
    output = loss(input_tensor, target_tensor)
    print(output)
    output2 = loss(input_tensor, target_tensor_prob)
    print(output2)
    output3 = loss(input_tensor, target_tensor_prob.softmax(dim=1))
    print(output3)

It gives the following output (the exact values vary from run to run, since input_tensor is random):
tensor(0.7072)
tensor(0.9613)
tensor(0.8930)

According to my understanding, the first two should be the same.

Hi Akshay!

When using CrossEntropyLoss with probabilistic (sometimes called “soft”)
targets, the input (the predictions) consists of unnormalized log-probabilities,
but the target consists of probabilities. Generating the target by applying
softmax() to a random tensor is just a convenient way to get legitimate
probabilities (non-negative values whose rows sum to one) for the sake of
the example.
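Just as a quick sketch of what “legitimate probabilities” means here:

import torch

t = torch.randn(3, 5).softmax(dim=1)
print(t.sum(dim=1))   # each row sums to one (up to floating-point rounding)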

CrossEntropyLoss does apply log_softmax() internally to its input, but
not to its target.
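You can verify this by computing the soft-target cross entropy by hand. A
minimal sketch (assuming the default reduction='mean'):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
inp = torch.randn(3, 5)                 # unnormalized log-probabilities (logits)
tgt = torch.randn(3, 5).softmax(dim=1)  # probabilistic ("soft") targets

ce = torch.nn.CrossEntropyLoss()(inp, tgt)
# log_softmax() is applied to the input only; the target is used as-is
manual = -(tgt * F.log_softmax(inp, dim=1)).sum(dim=1).mean()
print(torch.allclose(ce, manual))       # True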

No, you have your target_tensor_prob backward. Its first row, for example,
should be [0.0, 1.0], as this would mean a probability of zero of being in
your “0-class” and a probability of one of being in your “1-class.” This is what
properly corresponds to the integer class label of 1 that appears as the first
element of your target_tensor.
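If you flip the rows (or, equivalently, build the one-hot targets directly from
target_tensor), the first two losses do come out equal. A quick check along
those lines:

import torch
import torch.nn.functional as F

target_tensor = torch.tensor([1, 1, 0, 1])
# one-hot probabilities that actually match the integer labels
target_prob = F.one_hot(target_tensor, num_classes=2).float()

input_tensor = torch.randn(4, 2)
loss = torch.nn.CrossEntropyLoss(reduction='mean')
print(torch.allclose(loss(input_tensor, target_tensor),
                     loss(input_tensor, target_prob)))   # True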

Best.

K. Frank
