I am using a “one-hot” implementation of cross-entropy loss, meaning the target is a vector rather than a class index. I need this kind of implementation for further research.
When I compare PyTorch's nn.CrossEntropyLoss (given the target as a class index rather than one-hot) with my implementation, my version doesn't learn anything. I suspect it has to do with vanishing gradients.
Both logits and targets are float tensors.
I’m not sure exactly what you are asking, but, in isolation, your
cross-entropy implementation looks mathematically correct
to me. However, it would appear that your loss returns a
vector of length equal to the batch size. (It’s not completely
clear where, or whether, the batch dimension appears in your loss.)
So you might need to sum your loss over the batch, but without
seeing how you use your loss, it’s hard to tell for sure whether
this is the problem.
Or perhaps your issue lies somewhere else, in code you haven’t posted.
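For illustration, here is a minimal sketch of a one-hot cross-entropy that makes the batch reduction explicit. The helper name `one_hot_cross_entropy` and its `reduction` argument are hypothetical (chosen to mirror the `reduction` options of `nn.CrossEntropyLoss`), and the sketch assumes logits and targets of shape `(batch, num_classes)`:

```python
import torch
import torch.nn.functional as F

def one_hot_cross_entropy(logits, target_1hot, reduction="mean"):
    """Hypothetical helper (not a PyTorch API): cross-entropy with a one-hot
    (or soft) float target of the same shape (batch, num_classes) as logits."""
    # Per-sample loss, shape (batch,)
    per_sample = torch.sum(-target_1hot * F.log_softmax(logits, dim=-1), dim=-1)
    if reduction == "mean":
        return per_sample.mean()
    if reduction == "sum":
        return per_sample.sum()
    if reduction == "none":
        return per_sample
    raise ValueError(f"unknown reduction: {reduction}")

logits = torch.randn(4, 5)
targ = torch.tensor([0, 3, 1, 4])                    # class indices (example values)
targ1hot = F.one_hot(targ, num_classes=5).float()    # same targets as one-hot floats

per_sample = one_hot_cross_entropy(logits, targ1hot, reduction="none")  # vector, shape (4,)
mean_loss = one_hot_cross_entropy(logits, targ1hot)                     # scalar
print(torch.allclose(mean_loss, F.cross_entropy(logits, targ)))
```

With `reduction="mean"` it should agree with the default of `F.cross_entropy`. Note that calling `backward()` on the unreduced per-sample vector raises an error, since an implicit gradient can only be created for a scalar output.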
Comparing nn.CrossEntropyLoss with your version, we see
that they give the same result, provided we sum your loss over
the batch to get a scalar:
```python
import torch

logits = torch.randn(2, 5)
# Class indices with shape (batch, 1), as scatter() requires (random example values)
targ = torch.randint(0, 5, (2, 1))
targ1hot = torch.zeros(logits.shape).scatter(1, targ, 1.0)

# loss1a is your "one-hot" version of CrossEntropyLoss;
# it gives a loss value for each sample in the batch
loss1a = torch.sum(-targ1hot * torch.nn.functional.log_softmax(logits, -1), -1)
# loss1b is your version summed over the batch
loss1b = loss1a.sum()
# loss1c uses torch.sum() to sum directly over both the classes and the batch
loss1c = torch.sum(-targ1hot * torch.nn.functional.log_softmax(logits, -1))
# loss2 is PyTorch's class-index CrossEntropyLoss
loss2 = torch.nn.CrossEntropyLoss(reduction='sum')(logits, targ[:, 0])

print(loss1b, loss1c, loss2)  # the three values agree
```
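As an aside, recent PyTorch versions (1.10 and later, as far as I know) let nn.CrossEntropyLoss take class probabilities directly, i.e. a float target with the same shape as the logits, so a one-hot target may not need a custom loss at all. A minimal check, assuming such a version, that the two target forms agree:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(3, 5)
targ = torch.tensor([2, 0, 4])  # class indices (example values)
# Float one-hot target with the same shape as logits
targ1hot = torch.nn.functional.one_hot(targ, num_classes=5).float()

loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
loss_index = loss_fn(logits, targ)      # class-index target (long)
loss_probs = loss_fn(logits, targ1hot)  # class-probability target (float)
print(torch.allclose(loss_index, loss_probs))
```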
A side note: in general, for numerical reasons, log_softmax()
is preferred over log(softmax()).
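To see why, a small sketch: with widely spread logits, softmax() underflows to exactly 0 in float32, so log(softmax()) returns -inf, whereas log_softmax() is computed via the log-sum-exp trick and stays finite. The logit values here are arbitrary examples:

```python
import torch
import torch.nn.functional as F

# Logits with a large spread: exp(-200) underflows to 0 in float32
logits = torch.tensor([[0.0, 200.0]])

naive = torch.log(F.softmax(logits, dim=-1))  # log(0) gives -inf for the first class
stable = F.log_softmax(logits, dim=-1)        # stays finite: approximately [-200, 0]

print(naive)
print(stable)
```

An infinite loss value like this one propagates into inf/NaN gradients, which is another reason to prefer log_softmax().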
That is odd, and suggests that your problem is hiding somewhere
in “all the rest of the code.” Could you try to trim things
down to a small, simple (and runnable) example that displays the problem?
After digging into the original nn.CrossEntropyLoss implementation, I found the problem: I was summing the loss over the batch instead of averaging it. I was surprised that summing made such a difference, since the batch I am using is very small (one sentence of roughly 30 words).
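A likely reason the difference matters even for ~30 words: assuming one prediction per word, the loss is reduced over ~30 terms, and with reduction='sum' every gradient is 30 times larger than with the default reduction='mean'. With plain SGD that acts like a 30-fold larger learning rate. A minimal sketch that checks this scaling, with the sizes (30 tokens, 10 classes) assumed for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_classes = 30, 10  # assumed sizes: one sentence of ~30 words
logits = torch.randn(num_tokens, num_classes, requires_grad=True)
targ = torch.randint(0, num_classes, (num_tokens,))

# Gradient with the per-token losses averaged (the nn.CrossEntropyLoss default)
F.cross_entropy(logits, targ, reduction="mean").backward()
grad_mean = logits.grad.clone()

# Gradient with the per-token losses summed
logits.grad = None
F.cross_entropy(logits, targ, reduction="sum").backward()
grad_sum = logits.grad.clone()

# Summing scales every gradient by the number of tokens
print(torch.allclose(grad_sum, num_tokens * grad_mean))
```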