How is backward() calculated in CrossEntropyLoss?

I have a simple Linear model and I need to calculate the loss for it. I applied two loss functions, CrossEntropyLoss and NLLLoss, and I want to understand how the grads are calculated by each of these two methods.

On the output layer I have 4 neurons, which means I am classifying into 4 classes.

L1 = nn.Linear(2,4)

When I use CrossEntropyLoss I get grads for all the parameters:

L1.weight.grad
tensor([[ 0.1212,  0.2424],
        [ 0.1480,  0.2961],
        [ 0.1207,  0.2414],
        [-0.3899, -0.7798]])

But when I use NLLLoss, I only get grads for the parameters of the true class:

L1.weight.grad
tensor([[ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [-1., -2.]])

Could you explain how the grads are calculated in both methods?

I cannot reproduce the issue and get the same, expected results:

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.NLLLoss()

L1 = nn.Linear(2,4)
x = torch.randn(1, 2)
target = torch.randint(0, 4, (1,))

# CrossEntropyLoss
out1 = L1(x)
loss1 = criterion1(out1, target)
print(loss1)
# tensor(1.6417, grad_fn=<NllLossBackward0>)

loss1.backward()
print(L1.weight.grad)
# tensor([[-0.7159,  0.1766],
#         [-0.5968,  0.1472],
#         [-0.1265,  0.0312],
#         [ 1.4391, -0.3549]])
L1.zero_grad()

# NLLLoss
out2 = L1(x)
loss2 = criterion2(F.log_softmax(out2, dim=1), target)
print(loss2)
# tensor(1.6417, grad_fn=<NllLossBackward0>)

loss2.backward()
print(L1.weight.grad)
# tensor([[-0.7159,  0.1766],
#         [-0.5968,  0.1472],
#         [-0.1265,  0.0312],
#         [ 1.4391, -0.3549]])

I guess you might be missing the F.log_softmax when using nn.NLLLoss?


That is right. But what I actually wanted to know is how the grads are calculated with the softmax and without it.

Let’s suppose we have an output like [.1, .2, .3, .4] and the target is [3]. Could you tell me how the grads are actually calculated with CrossEntropyLoss and with pure NLLLoss? I cannot understand how the grads are computed.
It would also be fine if you showed me with your own example.
Thanks

You can see the manual approach to these loss functions here, which would correspond to:

import torch
import torch.nn as nn
import torch.nn.functional as F

L1 = nn.Linear(2,4)
x = torch.randn(1, 2)
target = torch.randint(0, 4, (1,))

# PyTorch approach
criterion = nn.CrossEntropyLoss(reduction="sum")
ref = L1(x)
loss_ref = criterion(ref, target)
print(loss_ref)
# tensor(0.9717, grad_fn=<NllLossBackward0>)
loss_ref.backward()
print(L1.weight.grad)
# tensor([[ 0.0744, -0.0261],
#         [-0.6789,  0.2381],
#         [ 0.2678, -0.0939],
#         [ 0.3368, -0.1181]])
L1.zero_grad()

# manual
out = L1(x)
#probs = torch.softmax(out, dim=1)
probs = torch.exp(out) / torch.sum(torch.exp(out), dim=1)

# Manual approach using your formula
one_hot = F.one_hot(target, num_classes = 4)
ce = (one_hot * torch.log(probs + 1e-7))[one_hot.bool()]
ce = -1 * ce.sum()
print(ce)
# tensor(0.9717, grad_fn=<MulBackward0>)
ce.backward()
print(L1.weight.grad)
# tensor([[ 0.0744, -0.0261],
#         [-0.6789,  0.2381],
#         [ 0.2678, -0.0939],
#         [ 0.3368, -0.1181]])
L1.zero_grad()

The second approach shows the manual calls, and you should be able to recompute the backward pass based on the operations used (torch.exp, torch.sum, indexing, etc.).
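
As a minimal sketch of that backward pass (assuming the same setup as above): for cross entropy the gradient with respect to the logits is softmax(out) - one_hot(target), and the chain rule through the Linear layer turns that into the weight gradient:

import torch
import torch.nn as nn
import torch.nn.functional as F

L1 = nn.Linear(2, 4)
x = torch.randn(1, 2)
target = torch.randint(0, 4, (1,))

# autograd reference
out = L1(x)
nn.CrossEntropyLoss(reduction="sum")(out, target).backward()

# analytical gradient w.r.t. the logits: softmax(out) - one_hot(target)
grad_logits = (torch.softmax(out, dim=1) - F.one_hot(target, num_classes=4)).detach()

# chain rule through the Linear layer: dL/dW = grad_logits^T @ x
manual_weight_grad = grad_logits.t() @ x
print(torch.allclose(L1.weight.grad, manual_weight_grad))
# True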


Hi Mahdi!

Before addressing a possible point of confusion, let me emphasize that you
should not feed the output of a Linear directly into NLLLoss.

The output of a Linear will typically include positive numbers, which, when
interpreted as log-probabilities, correspond to invalid “probabilities” that are
greater than one. If you permit “probabilities” greater than one, NLLLoss
will return negative values, will be unbounded below, and your training will
diverge.
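
For instance, a quick sketch of this point: the larger the (invalid) positive “log-probability” assigned to the target class, the more negative the loss becomes, without bound:

import torch

targ = torch.tensor([0])
for value in (1.0, 10.0, 100.0):
    # "log-probabilities" greater than zero, i.e. "probabilities" greater than one
    fake_log_probs = torch.full((1, 4), value)
    print(torch.nn.NLLLoss()(fake_log_probs, targ))
# tensor(-1.), tensor(-10.), tensor(-100.)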

As to the zero gradients you are seeing: this is to be expected.

NLLLoss plucks out the log-probability of the “true” class (as specified
by the passed-in target), and doesn’t depend on the non-true-class
probabilities. So those gradients are zero.

Consider:

>>> import torch
>>> torch.__version__
'1.13.0'
>>> _ = torch.manual_seed (2022)
>>> log_probs = torch.rand (1, 5, requires_grad = True)
>>> targ = torch.tensor ([2])   # "true" class is 2
>>> loss = torch.nn.NLLLoss() (log_probs, targ)
>>> loss
tensor(-0.7588, grad_fn=<NllLossBackward0>)
>>> loss.backward()
>>> log_probs.grad   # only non-zero for "true" class
tensor([[ 0.,  0., -1.,  0.,  0.]])
>>> with torch.no_grad():   # change values for some "non-true" classes
...     log_probs[0, 0] = 66.6
...     log_probs[0, 4] = 99.9
...
>>> torch.nn.NLLLoss() (log_probs, targ)   # loss doesn't change
tensor(-0.7588, grad_fn=<NllLossBackward0>)

When you pass the output of Linear through log_softmax() (or softmax(),
for that matter), it mixes the classes together so that the “true”-class value
(that NLLLoss plucks out) depends on all of the outputs of Linear and you
get non-zero gradients for all elements of weight (and bias).
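
To make this concrete with the [.1, .2, .3, .4] example from above, here is a small sketch: the gradient of NLLLoss(log_softmax(z), target) with respect to z is softmax(z) - one_hot(target), so every class receives a non-zero gradient:

import torch
import torch.nn.functional as F

z = torch.tensor([[0.1, 0.2, 0.3, 0.4]], requires_grad=True)
target = torch.tensor([3])

loss = torch.nn.NLLLoss()(F.log_softmax(z, dim=1), target)
loss.backward()

print(z.grad)
# approximately tensor([[ 0.2138,  0.2363,  0.2612, -0.7113]])

# matches softmax(z) - one_hot(target), computed without autograd
with torch.no_grad():
    print(torch.softmax(z, dim=1) - F.one_hot(target, num_classes=4))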

Best.

K. Frank


My problem was about the derivative of the softmax; now I understand how to calculate it and apply it to the network.