Cross entropy in batches

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 3, 5, requires_grad=True)         # intended as (batch, categories, classes)
target = torch.empty(3, 3, dtype=torch.long).random_(5)  # (batch, categories)
output = loss(input, target)                             # raises the error below
output.backward()

RuntimeError: Expected target size [3, 5], got [3, 3]

In this example, my batch size is 3, each sample contains 3 categories, and each category holds a probability distribution over 5 classes, so the input shape is (3, 3, 5).

My target holds the class index for each category, so its shape is (batch size, categories), which here is (3, 3).
Unless I convert each category's index into a one-hot vector, turning the target into (3, 3, 5), it won't work, but that is not a good approach.

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
From the documentation, you can see that the second dimension of the input corresponds to the number of classes (C). That is also why the error message asks for a target of size [3, 5]: with an input of shape (3, 3, 5), PyTorch reads N=3, C=3, d1=5, and therefore expects a target of shape (N, d1) = (3, 5) rather than (3, 3).
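To make the shape convention concrete, here is a minimal sketch; it only demonstrates what target shape the loss accepts for a (3, 3, 5) input, not a fix for the original problem:

import torch
import torch.nn as nn

# With input (3, 3, 5), dim 1 is read as the class dimension:
# N=3, C=3, d1=5, so the expected target shape is (N, d1) = (3, 5).
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 3, 5, requires_grad=True)
target = torch.empty(3, 5, dtype=torch.long).random_(3)  # class indices in [0, 3)
output = loss(input, target)  # runs, but treats the 3 "categories" as classes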
The following code works:

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, 3, requires_grad=True)         # (batch, classes, categories) = (N, C, d1)
target = torch.empty(3, 3, dtype=torch.long).random_(5)  # (batch, categories) = (N, d1)
output = loss(input, target)
output.backward()
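If your logits are already laid out as (batch, categories, classes) like in the question, you do not have to change how they are produced; moving the class dimension to position 1 with permute is enough. A minimal sketch (the name logits is mine, not from the thread):

import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()
logits = torch.randn(3, 3, 5, requires_grad=True)        # (batch, categories, classes)
target = torch.empty(3, 3, dtype=torch.long).random_(5)  # (batch, categories)

# permute returns a view, no copy: (3, 3, 5) -> (3, 5, 3), i.e. (N, C, d1)
output = loss(logits.permute(0, 2, 1), target)
output.backward()

As for the one-hot idea: PyTorch 1.10 and later also accepts floating-point class-probability targets of the same shape as the input, but the class dimension still has to be dim 1, so that route does not avoid the layout issue.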