How is backward() calculated in CrossEntropyLoss?

You can see the manual approach for these loss functions here, which corresponds to:

import torch
import torch.nn as nn
import torch.nn.functional as F

L1 = nn.Linear(2, 4)
x = torch.randn(1, 2)
target = torch.randint(0, 4, (1,))

# PyTorch approach
criterion = nn.CrossEntropyLoss(reduction="sum")
ref = L1(x)
loss_ref = criterion(ref, target)
print(loss_ref)
# tensor(0.9717, grad_fn=<NllLossBackward0>)
loss_ref.backward()
print(L1.weight.grad)
# tensor([[ 0.0744, -0.0261],
#         [-0.6789,  0.2381],
#         [ 0.2678, -0.0939],
#         [ 0.3368, -0.1181]])
L1.zero_grad()

# manual
out = L1(x)
# probs = torch.softmax(out, dim=1)
# equivalent manual softmax (keepdim=True keeps the broadcast correct for any batch size)
probs = torch.exp(out) / torch.sum(torch.exp(out), dim=1, keepdim=True)

# Manual approach using your formula
one_hot = F.one_hot(target, num_classes = 4)
ce = (one_hot * torch.log(probs + 1e-7))[one_hot.bool()]
ce = -1 * ce.sum()
print(ce)
# tensor(0.9717, grad_fn=<MulBackward0>)
ce.backward()
print(L1.weight.grad)
# tensor([[ 0.0744, -0.0261],
#         [-0.6789,  0.2381],
#         [ 0.2678, -0.0939],
#         [ 0.3368, -0.1181]])
L1.zero_grad()

The second approach spells out the same computation with manual calls, so you can recompute the backward pass yourself from the operations that were used (torch.exp, torch.sum, indexing, etc.).
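
As a minimal sketch (reusing L1, x, and target from above, and assuming reduction="sum"), the gradient of the cross-entropy loss with respect to the logits is softmax(out) - one_hot(target); chaining that through the linear layer reproduces the weight and bias gradients that autograd computes:

# analytic gradient sketch, reusing L1, x, target from above
out = L1(x)                                 # logits, shape (1, 4)
probs = torch.softmax(out, dim=1)           # softmax probabilities
one_hot = F.one_hot(target, num_classes=4)  # shape (1, 4)

# dLoss/dlogits for CrossEntropyLoss(reduction="sum") is probs - one_hot
grad_logits = (probs - one_hot).detach()

# chain rule through out = x @ W.T + b
grad_weight = grad_logits.t() @ x           # shape (4, 2); should match L1.weight.grad
grad_bias = grad_logits.sum(dim=0)          # shape (4,); should match L1.bias.grad
print(grad_weight)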
