Justification for LogSoftmax being better than Log(Softmax)

There’ve been other questions on this forum asking about LogSoftmax vs Softmax. This question is more focused on why LogSoftmax is claimed to be better (both numerically and in terms of speed) than applying Log to the output of Softmax. The claim is made on this doc page:

https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html

But softmax by itself is actually numerically stable: its implementation already uses the max trick (see link below). Therefore, it’s unclear to me why applying log to its output makes the computation unstable.
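To be concrete, by the max trick I mean the usual rescaling sketched below (made-up function names, not PyTorch’s actual implementation):

import torch

def naive_softmax(x, dim=0):
    # Direct formula: exp() of large inputs overflows to inf, giving inf / inf = nan.
    e = torch.exp(x)
    return e / e.sum(dim=dim, keepdim=True)

def max_trick_softmax(x, dim=0):
    # Subtract the per-dim max first, so every exponent is <= 0 and exp() cannot overflow.
    z = x - x.max(dim=dim, keepdim=True).values
    e = torch.exp(z)
    return e / e.sum(dim=dim, keepdim=True)

x = torch.tensor([100.0, 200.0, 300.0])
print(naive_softmax(x))      # tensor([nan, nan, nan])
print(max_trick_softmax(x))  # finite, matches torch.softmax(x, dim=0)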

I’d really appreciate if someone can help me understand this via some math / code.

Hi Zhihan!

First, try some numerical experiments.

Use the double-precision version of the naive expression as your
assumed “correct” result. (The double-precision calculation will also
have the numerical overflow issue, but it won’t set in as soon.)

true_result = torch.log(torch.softmax(alpha * torch.tensor([-1.0, 0.0, 1.0]).double(), dim=0))

Then compare the single-precision results from

torch.log(torch.softmax(alpha * torch.tensor([-1.0, 0.0, 1.0]), dim=0))
# and
torch.log_softmax(alpha * torch.tensor([-1.0, 0.0, 1.0]), dim=0)

with one another and with your “true” result for increasing values of alpha,
say, alpha = 2, 5, 10, 20, 50, 100, ..., and see how the results behave.
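For example, something along these lines (just a sketch) prints all three side by side:

import torch

base = torch.tensor([-1.0, 0.0, 1.0])
for alpha in [2.0, 5.0, 10.0, 20.0, 50.0, 100.0, 1000.0]:
    true_result = torch.log(torch.softmax(alpha * base.double(), dim=0))  # double-precision "true" result
    naive = torch.log(torch.softmax(alpha * base, dim=0))                 # single-precision log(softmax())
    fused = torch.log_softmax(alpha * base, dim=0)                        # single-precision log_softmax()
    print('alpha =', alpha)
    print(true_result)
    print(naive)
    print(fused)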

Second, write two of your own versions of log_softmax(), one where you
just use the naive log (softmax()) approach, and a second where you
apply the “log-sum-exp trick” to the sum of exponentials in the denominator
of the formula for softmax().

Does the log-sum-exp trick significantly reduce the overflow issue? This
is pretty much how log_softmax() is implemented in pytorch.
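As a starting point, the two versions could look something like this (my own names, not necessarily pytorch's exact code):

import torch

def naive_log_softmax(x, dim=0):
    # softmax() itself is stable (max trick), but its tiny outputs underflow
    # to 0 in single precision, and log(0) is -inf.
    return torch.log(torch.softmax(x, dim=dim))

def lse_log_softmax(x, dim=0):
    # log-sum-exp trick: log_softmax(x) = (x - m) - log(sum(exp(x - m))), with m = max(x).
    # exp(x - m) can still underflow for very negative entries, but those terms
    # sit inside a sum that is at least 1, so the outer log stays finite.
    m = x.max(dim=dim, keepdim=True).values
    return (x - m) - torch.log(torch.exp(x - m).sum(dim=dim, keepdim=True))

x = 1000 * torch.tensor([-1.0, 0.0, 1.0])
print(naive_log_softmax(x))  # tensor([-inf, -inf, 0.])
print(lse_log_softmax(x))    # tensor([-2000., -1000., 0.])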

Best.

K. Frank


Hi KFrank!

Thanks a lot for the code example you gave; I gained a much better understanding of this issue. I’m sharing my results and interpretation below for you and others.

Trying alpha = 100 gives:

tensor([-200., -100.,    0.], dtype=torch.float64)  # log + softmax
tensor([-200., -100.,    0.])  # logsoftmax

Trying alpha = 1000 gives:

tensor([-inf, -inf, 0.])  # log + softmax
tensor([-2000., -1000.,     0.])  # logsoftmax

This immediately suggests that, when log and softmax are applied separately, any softmax output that underflows to zero makes the subsequent log return negative infinity.
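Looking at the softmax output by itself (same alpha = 1000 input, single precision) shows exactly where the -inf comes from:

torch.softmax(1000 * torch.tensor([-1.0, 0.0, 1.0]), dim=0)
# tensor([0., 0., 1.])  (the first two probabilities underflow to exactly 0)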

For an even more succinct example, where the input of log underflows to zero (exp is just one way to produce such a value):

torch.log(torch.exp(torch.tensor([-2000.0])))  # tensor([-inf])

but if we adopt the log-softmax idea, i.e. simplify log(exp(-2000)) before evaluating it, the answer is exactly -2000.
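Equivalently, the stable answer can be written via the identity log_softmax(x) = x - logsumexp(x), using torch.logsumexp (which is itself numerically stabilized):

x = 1000 * torch.tensor([-1.0, 0.0, 1.0])
x - torch.logsumexp(x, dim=0)  # tensor([-2000., -1000., 0.])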

*Numerical overflow might not be relevant in this context, though, since it’s already ruled out by the max trick inside the softmax implementation; the failure above is underflow followed by log(0).
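(As a quick check on that point: exponentiating the alpha = 1000 input directly does overflow, while softmax itself stays finite.)

torch.exp(1000 * torch.tensor([-1.0, 0.0, 1.0]))             # tensor([0., 1., inf])
torch.softmax(1000 * torch.tensor([-1.0, 0.0, 1.0]), dim=0)  # tensor([0., 0., 1.])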

Merry Christmas,
Zhihan