Hi Zhihan!
The short, practical answer is that it comes down to what you typically
do with the log-softmax of the logits: You pass it into a loss function
such as nll_loss(). (Doing this gives you, in effect, the cross-entropy
loss.)
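Here is a minimal sketch of that equivalence. The logits and targets are made-up values, used only for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)            # batch of 4 samples, 3 classes (made-up data)
targets = torch.tensor([0, 2, 1, 2])  # made-up integer class labels

# log_softmax() followed by nll_loss() ...
loss_two_step = F.nll_loss(F.log_softmax(logits, dim=1), targets)
# ... should match cross_entropy() applied directly to the raw logits.
loss_direct = F.cross_entropy(logits, targets)

print(loss_two_step.item(), loss_direct.item())
```

The two printed values should agree up to floating-point round-off.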
If you were to pass the raw logits into nll_loss() you would get an
ill-behaved loss function that is unbounded below. That is, by making,
for example, the biases of your last linear layer (the one that produces
the logits) arbitrarily large, you make the logits arbitrarily large, and
the loss function becomes arbitrarily “good,” that is, large and negative.
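To see this concretely, here is a small sketch. Adding the same constant to every logit stands in for growing all of the last layer's biases together; the values and the names raw_losses and normalized_losses are just for illustration:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 3.0]])  # one sample, three classes (made-up values)
target = torch.tensor([2])                # made-up label

raw_losses, normalized_losses = [], []
for shift in [0.0, 10.0, 1000.0]:
    # Adding the same constant to every logit mimics growing all of the
    # last layer's biases together.
    shifted = logits + shift
    # Raw logits into nll_loss(): the loss is just -(target logit).
    raw_losses.append(F.nll_loss(shifted, target).item())
    # Normalizing with log_softmax() first removes the shift.
    normalized_losses.append(
        F.nll_loss(F.log_softmax(shifted, dim=1), target).item()
    )

print(raw_losses)         # keeps decreasing as the shift grows
print(normalized_losses)  # stays essentially the same
```

The raw-logit loss can be pushed down without limit, while the log_softmax() version does not care about the shift at all.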
But why is this?
As you have noticed, the log() function is almost, but not quite, the
inverse of the softmax() function. The difference is a constant (the same
across classes for a given set of logits).
This constant is the difference between proper log-probabilities and
the “unnormalized log-probabilities” we call logits, and this is the
constant that becomes arbitrarily large when the nll_loss() function
diverges to -inf. Calculating log_softmax (logits) normalizes this
constant away. (And, in some sense, that’s all it does, because
log_softmax (log_softmax (logits)) = log_softmax (logits).)
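A quick check of both claims, as a sketch with made-up logits: the gap between the logits and log (softmax (logits)) should be the same number for every class, and applying log_softmax() a second time should change nothing.

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.5])  # made-up logits for a single sample

log_probs = torch.log(torch.softmax(logits, dim=0))
offset = logits - log_probs  # should be the same number for every class
print(offset)

once = torch.log_softmax(logits, dim=0)
twice = torch.log_softmax(once, dim=0)  # applying it again should change nothing
print(torch.allclose(once, twice, atol=1e-6))
```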
This constant is the log of the denominator in the formula for
softmax(), namely log (sum_i {exp (logit_i)}).
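As a sketch (again with made-up logits), subtracting that constant by hand should reproduce log_softmax(). I use torch.logsumexp() here to compute the constant:

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.5])  # made-up logits for a single sample

log_z = torch.logsumexp(logits, dim=0)   # log (sum_i {exp (logit_i)})
manual = logits - log_z                  # remove the constant by hand
builtin = torch.log_softmax(logits, dim=0)

print(torch.allclose(manual, builtin, atol=1e-6))
```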
log_softmax() has a further technical advantage: Calculating the log()
of a sum of exp() terms in the normalization constant can become
numerically unstable. PyTorch’s log_softmax() uses the “log-sum-exp
trick” to avoid this numerical instability.
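Here is a sketch of the idea behind the trick, not a claim about how PyTorch implements it internally. The logits are made-up values chosen large enough that exp() overflows in single precision:

```python
import torch

logits = torch.tensor([1000.0, 1000.0])  # large made-up logits

# Naive route: exp(1000) overflows to inf in float32, which ruins the result.
naive = logits - torch.log(torch.exp(logits).sum())

# Log-sum-exp trick: subtract the max first, so the largest term is exp(0) = 1.
m = logits.max()
shifted = logits - m
stable = shifted - torch.log(torch.exp(shifted).sum())

print(naive)
print(stable)
print(torch.log_softmax(logits, dim=0))
```

The naive version comes out as -inf for both classes, while the shifted version and log_softmax() should agree on finite values.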
From this perspective, the purpose of PyTorch’s log_softmax()
function is to remove this normalization constant – in a numerically
stable way – from the raw, unnormalized logits we get from a linear
layer so we can pass them into a useful loss function.
Best.
K. Frank