Loss nll - formula in documentation

Greetings,

I am a bit confused by the documented formula for the negative log-likelihood loss (NLLLoss), which gives the unreduced loss for sample n as:

l_n = -w_(y_n) · x_(n,y_n)

What are those x_(n,y_n)? The documentation says x is the input, but the loss is not calculated from the input. A loss should be calculated from the output and the target, should it not?

Best,
PiF

edit: is the following assumption correct?
x_(n,y_n) is the entry of the network output vector that corresponds to the probability of the target y_n, where n denotes the sample.

Hi Physics!

Yes, the terminology in the documentation is somewhat unfortunate.

The argument to the loss functions that pytorch refers to as input is
indeed the “output” of the network (or derived from it). I prefer to call
this the “prediction.” The loss function then compares the “prediction”
with the “ground truth” label, which is the argument that pytorch calls
the target.

I guess one can think of the “prediction” as being the “input” to the loss
function (but so is the target).

So “output” of model = “prediction” = “input” to loss function. But using
the term input in the loss-function documentation has always seemed
a bit confusing to me, as it suggests to me the input to the model, rather
than the output of the model. But I’ve made my peace with the pytorch
terminology.
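To make the naming concrete, here is a minimal sketch, assuming a hypothetical toy classifier (a single Linear layer followed by LogSoftmax, with made-up sizes and random data). The model's output, the “prediction,” is what gets passed as the loss function's input argument, and the ground-truth labels go in as target; the features fed to the model never reach the loss directly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: 4 samples, 3 features, 5 classes.
features = torch.randn(4, 3)          # the *model's* input
target = torch.tensor([0, 2, 4, 1])   # ground-truth class index for each sample

model = nn.Sequential(nn.Linear(3, 5), nn.LogSoftmax(dim=1))
prediction = model(features)          # the model's output: log-probabilities, shape (4, 5)

loss_fn = nn.NLLLoss()
# What the loss-function docs call `input` is the prediction, not `features`.
loss = loss_fn(input=prediction, target=target)
print(loss.item())
```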

Not quite. What you denote “x_(n,y_n)” is the predicted log-probability
that sample n belongs to its target class y_n (rather than the probability).

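A small sketch of the indexing, using random log-probabilities and the default settings (no class weights, so every w_(y_n) is 1, and no ignored index): with reduction="none", the per-sample loss is just minus the entry x[n, y_n], i.e. minus the log-probability that row n assigns to column target[n].

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, C = 4, 5
logits = torch.randn(N, C)
log_probs = F.log_softmax(logits, dim=1)    # x in the documented formula
target = torch.tensor([0, 2, 4, 1])         # y_n for each sample n

# l_n = -x[n, y_n]: for each row n, pick the column given by target[n].
manual = -log_probs[torch.arange(N), target]

builtin = F.nll_loss(log_probs, target, reduction="none")
print(torch.allclose(manual, builtin))      # True

# The default reduction="mean" averages the l_n over the batch.
print(torch.allclose(manual.mean(), F.nll_loss(log_probs, target)))  # True
```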
So, typically, the outputs of the final Linear layer in the model are the
predicted raw-score logits. When you pass them through LogSoftmax, you
get the predicted log-probabilities that you then pass into NLLLoss.
(If you pass the logits through Softmax, you get probabilities, but it’s
numerically more stable to work with the log-probabilities.)
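As a rough illustration of the stability point, with made-up extreme logits for a single sample: computing log(softmax(...)) in two steps lets the small probability underflow to 0, so its log becomes -inf, whereas log_softmax computes the log-probability directly and stays finite.

```python
import torch
import torch.nn.functional as F

# Hypothetical extreme logits: one sample, two classes.
logits = torch.tensor([[0.0, 1000.0]])
target = torch.tensor([0])  # the target is the "unlikely" class

# Route 1: probabilities first, then log. exp(-1000) underflows to 0 in float32.
naive = torch.log(F.softmax(logits, dim=1))
# Route 2: log-probabilities computed directly.
stable = F.log_softmax(logits, dim=1)

print(naive)    # first entry is -inf
print(stable)   # first entry is -1000
print(F.nll_loss(naive, target))   # loss is inf
print(F.nll_loss(stable, target))  # loss is 1000
```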

(For convenience, pytorch’s CrossEntropyLoss combines LogSoftmax
with NLLLoss, so you can pass in the logits directly without first
passing them through a separate LogSoftmax function.)
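A quick sketch checking that equivalence on random data (the logits here just stand in for the output of a final Linear layer): the two-step LogSoftmax + NLLLoss route and the one-step CrossEntropyLoss route give the same loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(4, 5)            # raw scores, e.g. from a final nn.Linear
target = torch.tensor([0, 2, 4, 1])

# Two-step route: LogSoftmax, then NLLLoss.
two_step = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
# One-step route: CrossEntropyLoss takes the logits directly.
one_step = nn.CrossEntropyLoss()(logits, target)

print(torch.allclose(two_step, one_step))  # True
```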

Best,

K. Frank


Thank you very much, K. Frank, for this detailed explanation!
I guess calling the network output “input” makes sense from a programmer’s perspective, since it is an input to a function.

It is all clear now.
PiF