Does nn.CrossEntropyLoss internally build a one-hot encoding of the target?

I’m trying to build up an understanding of how nn.CrossEntropyLoss works. I see that it combines log_softmax and NLLLoss but am trying to work out exactly how.

Suppose a model outputs logits of [1.761, 0.590, 3.157] and the ‘target’ is class index 2 (the third class). So if I manually assembled a one-hot encoded vector, the target would look like [0, 0, 1].

When you use nn.CrossEntropyLoss, according to the docs, you can pass ‘class indices in the range [0,C) where C is the number of classes’ as the target. So in the above example, I would pass 2.
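For concreteness, this is roughly the call I have in mind (just a toy snippet reusing the numbers above, with a batch dimension of 1):

```python
import torch
import torch.nn as nn

logits = torch.tensor([[1.761, 0.590, 3.157]])  # shape (N=1, C=3)
target = torch.tensor([2])                      # class index, not one-hot

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)
print(loss)
```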

According to this Stack Overflow answer, nn.CrossEntropyLoss will simply ‘index into the output probability’ with the target (i.e. 2 indexes into the logits and selects 3.157), and then that number alone is used to compute the loss.

Is that true? Does this loss function simply index into the logits and work with that number? Or does it internally construct a one-hot encoding vector with a 1 in the target index position, and 0s elsewhere?

It seems to me that the latter is a much better option for computing loss because not only are you comparing the correct target to its predicted score, but you’re also comparing all the incorrect scores to 0. Surely that helps the model to learn faster? Or is my intuition incorrect here?

Thanks in advance.

Yes, you can directly index into the logits, as shown in the small example below, which calculates the loss manually using an explicit but slow approach.
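A minimal sketch of such a manual computation (the tensor values are just reused from your post for illustration):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.761, 0.590, 3.157]])
target = torch.tensor([2])

# Explicit approach: apply log_softmax, then pick the log-probability at the target index.
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(logits.size(0)), target].mean()

# Built-in criterion for comparison.
builtin_loss = F.cross_entropy(logits, target)
print(manual_loss, builtin_loss)  # both values should match
```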

No, as it would be wasteful to create the one-hot encoded target first and multiply it by the full logit tensor. The result would be equal to simply indexing the logit at the position where the one-hot encoded target has its 1. The limitation is of course that indexing won’t work with “soft targets”, i.e. target tensors which contain probabilities rather than class indices. For this use case nn.CrossEntropyLoss accepts a FloatTensor with probabilities in newer PyTorch releases.
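A small sketch of the soft-target usage, assuming a reasonably recent PyTorch release (>= 1.10):

```python
import torch
import torch.nn as nn

logits = torch.tensor([[1.761, 0.590, 3.157]])

# Hard target as a class index:
hard_target = torch.tensor([2])

# Soft target as class probabilities (each row must sum to 1);
# this is only accepted in newer PyTorch releases.
soft_target = torch.tensor([[0.0, 0.0, 1.0]])

criterion = nn.CrossEntropyLoss()
print(criterion(logits, hard_target))
print(criterion(logits, soft_target))  # same value here, since this soft target is one-hot
```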

That’s not true, as you would multiply the logits of the uninteresting classes by zero, so these values do not change the loss at all.
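You can verify this quickly with something like the following sketch, which compares the one-hot multiplication against plain indexing:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.761, 0.590, 3.157]])
target = torch.tensor([2])
log_probs = F.log_softmax(logits, dim=1)

# One-hot approach: the zeros wipe out every log-probability except the target's.
one_hot = F.one_hot(target, num_classes=logits.size(1)).float()
loss_one_hot = -(one_hot * log_probs).sum(dim=1).mean()

# Indexing approach: pick the target's log-probability directly.
loss_index = -log_probs[torch.arange(logits.size(0)), target].mean()

print(loss_one_hot, loss_index)  # identical values
```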


Of course! There I was thinking my intuition was dodgy, but actually I just wasn’t thinking through the fact that 0 * any probability = 0. Indexing makes sense; as you say, it would be pointless to build the one-hot encoding anyway. Thanks!