How does nn.CrossEntropyLoss() work under the hood?

Hi there,

I would like to know how nn.CrossEntropyLoss() works under the hood. I implemented two toy examples. In the first one there is only 1 target and a prediction of shape (1, 3), where 1 is the batch size and 3 is the number of classes. I followed this example, which shows how to do this for those shapes. However, I would like to know how the loss is computed when there are multiple targets and the prediction has a shape of, say, (2, 3) (batch size = 2, number of classes = 3). How is the loss calculated then?

Following the instructions in the link above, I did this:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

target = torch.randint(0, 3, (1, 1)).squeeze(0)  # shape (1,): one class index
preds = torch.randn(1, 3)                          # shape (1, 3): logits for 3 classes

print("target:")
print(target)

print()
print("preds:")
print(preds)

loss_func = nn.CrossEntropyLoss()
loss = loss_func(preds, target)

print()
print("loss:", loss.item())

print()
print("my loss:")
# -x[target] + log(sum_j exp(x_j)); indexing with target keeps only the true class
print(-preds[0, target] + torch.log(torch.sum(torch.exp(preds))))
```
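For reference, the quantity computed by hand in the snippet above is the standard per-sample cross entropy on logits: for logits $x = (x_1, \dots, x_C)$ over $C$ classes and target class $y$, it is the negative log softmax probability of the target class:

$$
\ell(x, y) = -\log \frac{e^{x_y}}{\sum_{j=1}^{C} e^{x_j}} = -x_y + \log \sum_{j=1}^{C} e^{x_j}
$$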

But I wasn’t able to understand how PyTorch calculates the loss with the following variables:

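```python
import torch
import torch.nn as nn

torch.manual_seed(0)

target = torch.randint(0, 3, (1, 2)).squeeze(0)  # shape (2,): one class index per sample
preds = torch.randn(2, 3)                          # shape (2, 3): batch size 2, 3 classes

print("target:")
print(target)

print()
print("preds:")
print(preds)

loss_func = nn.CrossEntropyLoss()
loss = loss_func(preds, target)

print()
print("loss:", loss.item())
```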

Any help is appreciated, thanks! :slightly_smiling_face:

OK, so I figured it out. A huge shout-out to this thread. Basically, the predictions (logits) you get from the model are run through a log softmax function over the class dimension. From the result, the value at the target class index is picked for each sample. Finally, those values are negated, summed, and divided by the number of samples, i.e. the batch size (the length of the target tensor), not the number of classes. This averaging is the default `reduction="mean"`. :slightly_smiling_face:
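The three steps described above can also be written without a Python loop. This is a minimal sketch, assuming class-index targets (a 1-D tensor of integers) and the default arguments of `nn.CrossEntropyLoss`; it uses `torch.arange` plus advanced indexing for the lookup (`torch.gather` would work too). It draws fresh random data, so the printed numbers will not match the ones from the question:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

preds = torch.randn(2, 3)           # logits: batch size 2, 3 classes
target = torch.randint(0, 3, (2,))  # one class index per sample

# Step 1: log softmax over the class dimension
log_probs = F.log_softmax(preds, dim=1)

# Step 2: pick log p(target) for each sample (row i, column target[i])
picked = log_probs[torch.arange(preds.shape[0]), target]

# Step 3: negate and average over the batch (default reduction="mean")
manual_loss = -picked.mean()

builtin_loss = nn.CrossEntropyLoss()(preds, target)
print("manual: ", manual_loss.item())
print("builtin:", builtin_loss.item())
print(torch.allclose(manual_loss, builtin_loss))
```

Dividing by the batch size keeps the loss scale independent of how many samples are in a batch.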

And here is the code:

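```python
import torch
import torch.nn as nn
import torch.nn.functional as F

preds = torch.randn(3, 5)  # batch of 3 samples, 5 classes
print("preds:")
print(preds)

targets = torch.randint(0, 5, (1, 3)).squeeze(0)  # shape (3,): one class index per sample
print()
print("targets:")
print(targets)

loss_func = nn.CrossEntropyLoss()
loss = loss_func(preds, targets)
print()
print("PyTorch loss:", loss)

# Step 1: log softmax over the class dimension (dim=1)
preds_log_softmax = F.log_softmax(preds, 1)
print()
print("preds_log_softmax:")
print(preds_log_softmax)

# Step 2: sum the log-probability of the target class for each sample
my_loss = 0
for i in range(targets.shape[0]):
    my_loss += preds_log_softmax[i, targets[i].item()]

# Step 3: negate and average over the batch (default reduction="mean")
my_loss *= -1
my_loss /= targets.shape[0]  # batch size, not number of classes
print()
print("my loss:", my_loss)
```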
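To see where the division by the batch size comes from, the `reduction` argument of `nn.CrossEntropyLoss` can be switched from the default `"mean"` to `"none"` (one loss per sample) or `"sum"`. The sketch below uses random data and assumes class-index targets with no `weight` argument (with class weights, `"mean"` divides by the sum of the weights of the target classes rather than by the batch size). It also checks that the loss matches `F.log_softmax` followed by `F.nll_loss`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

preds = torch.randn(3, 5)            # batch of 3 samples, 5 classes
targets = torch.randint(0, 5, (3,))  # one class index per sample

# Per-sample losses: -log_softmax(preds)[i, targets[i]] for each i
per_sample = nn.CrossEntropyLoss(reduction="none")(preds, targets)
total = nn.CrossEntropyLoss(reduction="sum")(preds, targets)
mean = nn.CrossEntropyLoss()(preds, targets)  # default reduction="mean"

# Cross entropy on logits = log softmax followed by negative log likelihood loss
via_nll = F.nll_loss(F.log_softmax(preds, dim=1), targets)

print("per-sample:", per_sample)  # shape (3,)
print("sum:", total.item())
print("mean:", mean.item())
print("nll(log_softmax):", via_nll.item())
```

With `reduction="none"`, the loss for the (2, 3) case from the question would be a tensor of two values, and the default loss is simply their average.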