Hi there,
I would like to know how nn.CrossEntropyLoss() works under the hood. I implemented two toy examples: one where there is a single target and a prediction of shape (1, 3), where 1 is the batch size and 3 is the number of classes. I followed this example, which shows how to compute the loss by hand for the variables mentioned above. However, I would like to know how the loss is computed when there are multiple targets and the prediction has shape, say, (2, 3) (batch size = 2, number of classes = 3). How is the loss calculated then?
Following the instructions in the link above I did this:
import torch
import torch.nn as nn

torch.manual_seed(0)
target = torch.randint(0, 3, (1, 1)).squeeze(0)
preds = torch.randn(1, 3)
print("target:")
print(target)
print()
print("preds:")
print(preds)
loss_func = nn.CrossEntropyLoss()
loss = loss_func(preds, target)
print()
print("loss:", loss.item())
print()
print("my loss:---------------------------.")
# negative log-probability for every class; the loss is the entry at the target index
nll = -preds + torch.log(torch.sum(torch.exp(preds)))
print(nll[0, target.item()].item())
But I wasn’t able to work out how PyTorch calculates the loss with the following variables:
torch.manual_seed(0)
target = torch.randint(0, 3, (1, 2)).squeeze(0)
preds = torch.randn(2, 3)
print("target:")
print(target)
print()
print("preds:")
print(preds)
loss = loss_func(preds, target)
print()
print("loss:", loss.item())
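If I understand it right, the batched case should just apply the single-sample formula to each row and then average over the batch (assuming the default reduction='mean'). Here is a sketch of what I think is happening; the torch.logsumexp call is just a numerically stable version of log(sum(exp(...))):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
target = torch.randint(0, 3, (1, 2)).squeeze(0)  # shape (2,): one class index per sample
preds = torch.randn(2, 3)                        # shape (2, 3): batch of 2, 3 classes

loss = nn.CrossEntropyLoss()(preds, target)

# My guess: per-sample loss is -preds[i, target[i]] + logsumexp(preds[i]),
# then averaged over the batch (reduction='mean' is the default).
log_probs = preds - torch.logsumexp(preds, dim=1, keepdim=True)
per_sample = -log_probs[torch.arange(preds.size(0)), target]
manual = per_sample.mean()

print(loss.item(), manual.item())  # the two should match
```

If this is right, the single-sample case is just the special case where the batch average is over one element.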
Any help is appreciated, thanks!