How does nn.CrossEntropyLoss() work under the hood?

OK so I figured it out. A huge shout out to this thread. Basically, the predictions you get from the model (raw logits, one row per sample) are run through a log softmax over the class dimension. Then, for each sample, you take the log-probability at the index of its target class. Finally, those values are summed, negated, and divided by the number of samples in the batch (i.e. the length of the target tensor), not by the number of classes. That division is the default `reduction="mean"`.
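Written out as a formula, with my own notation: $x$ is the `(N, C)` logits tensor, $y_i$ is the target class index of sample $i$, and the default settings apply (no class `weight`, no `ignore_index`, no label smoothing):

```latex
\ell = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(x_{i,y_i})}{\sum_{c=1}^{C} \exp(x_{i,c})}
     = -\frac{1}{N} \sum_{i=1}^{N} \operatorname{log\_softmax}(x_i)_{y_i}
```

The inner fraction is the softmax probability of the correct class, so the loss is the average negative log-probability the model assigns to the right answer.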

And here is the code:

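```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Raw, unnormalized scores (logits) for a batch of 3 samples and 5 classes
preds = torch.randn(3, 5)
print("preds:")
print(preds)

# One class index in [0, 5) per sample, shape (3,)
targets = torch.randint(0, 5, (3,))
print()
print("targets:")
print(targets)

# Built-in loss; the default reduction="mean" averages over the batch
loss_func = nn.CrossEntropyLoss()
loss = loss_func(preds, targets)
print()
print("PyTorch loss:", loss)

# Step 1: log softmax over the class dimension (dim=1)
preds_log_softmax = F.log_softmax(preds, dim=1)
print()
print("preds_log_softmax:")
print(preds_log_softmax)

# Step 2: add up the log-probability of the correct class for each sample
my_loss = 0
for i in range(targets.shape[0]):
    my_loss += preds_log_softmax[i, targets[i].item()]

# Step 3: negate and divide by the number of samples (batch size), not classes
my_loss *= -1
my_loss /= targets.shape[0]
print()
print("my loss:", my_loss)
```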
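The loop above can also be written without a Python loop. The sketch below is my own addition, not from the thread: it uses advanced indexing to pick each row's target column, compares against `F.nll_loss` (cross entropy is log softmax followed by negative log likelihood), rebuilds log softmax from `torch.logsumexp`, and uses `reduction="sum"` to show that the mean is taken over samples. It assumes default loss settings (no `weight`, `ignore_index`, or `label_smoothing`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the comparison is reproducible

preds = torch.randn(3, 5)            # logits: (batch, classes)
targets = torch.randint(0, 5, (3,))  # class indices: (batch,)

# Built-in loss (default reduction="mean")
builtin = nn.CrossEntropyLoss()(preds, targets)

# Log softmax, and the same thing rebuilt as x - logsumexp(x) per row
log_probs = F.log_softmax(preds, dim=1)
log_probs_manual = preds - torch.logsumexp(preds, dim=1, keepdim=True)

# Loop-free pick of each sample's target column, then negate and average
picked = log_probs[torch.arange(preds.shape[0]), targets]  # shape (batch,)
manual = -picked.mean()

# Same result via negative log likelihood on the log-probabilities
via_nll = F.nll_loss(log_probs, targets)

# reduction="sum" skips the division; dividing by batch size recovers the mean
summed = nn.CrossEntropyLoss(reduction="sum")(preds, targets)

print("built-in:", builtin.item())
print("manual:  ", manual.item())
print("nll_loss:", via_nll.item())
print(torch.allclose(log_probs, log_probs_manual))
print(torch.allclose(builtin, manual), torch.allclose(builtin, via_nll))
print(torch.allclose(summed / preds.shape[0], builtin))
```

All the `allclose` checks should print `True`, which is a quick way to convince yourself that the division is by the batch size (3 here) rather than the number of classes (5).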