How to calculate log(1 - softmax(X)) numerically stably

Hi all,

I have a multiclass classification problem with some inter-class relationships, so a modified version of the cross-entropy loss function is used:

[image: modified cross-entropy loss formula], where P is the probability and M is the label.

In short, in addition to log_softmax(), I need to implement log(1 - softmax(X)); let's call it log1m_softmax(). However, log1m_softmax() is numerically unstable even with the LogSumExp trick.
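
For concreteness, here is the sort of thing I mean (just a sketch): even if the softmax itself goes through the stable log_softmax(), the subtraction from 1 still underflows for a dominant logit.

import torch

def log1m_softmax_attempt(x):
    # compute softmax through the stable log_softmax() first ...
    p = x.log_softmax(dim=0).exp()
    # ... but (1 - p) still underflows to exactly 0 for a dominant logit,
    # so the final log() returns -inf
    return (1.0 - p).log()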

I believe I understand the implementation of the numerically stable log_softmax(), which has been well explained here.

A related discussion can be found here, but it leads to another numerical problem: log(1 - exp(x)) when x underflows to zero.
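
For reference, my understanding of the log(1 - exp(x)) workaround from that discussion is roughly the following (a sketch of the usual split at x = -log(2)):

import math
import torch

def log1mexp(x):
    # numerically careful log(1 - exp(x)) for x <= 0:
    # use log(-expm1(x)) near zero and log1p(-exp(x)) for very negative x
    return torch.where(
        x > -math.log(2.0),
        torch.log(-torch.expm1(x)),
        torch.log1p(-torch.exp(x)),
    )

Even this careful version returns -inf once x has already rounded to exactly 0, which is the case I run into below.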

Please help me to find a stable log1m_softmax(), thank you!

Would this help?

Not exactly, because when x=0, the result of this log1mexp() is still -Inf.

This is a real case: when one value xi in X dominates LogSumExp(X), xi - LogSumExp(X) = 0 with the limited precision of float.
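
Concretely (a quick sketch of the failing case):

import torch

X = torch.tensor([-1000.0, -1000.0, 1000.0])
ls = X.log_softmax(dim=0)
# ls is [-2000., -2000., 0.]: the dominant entry has rounded to exactly 0,
# so log(1 - exp(0)) = log(0) = -inf no matter how carefully it is evaluated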

Can you not also get away with using the original trick for log-softmax, i.e., subtracting the max value?

In the original log-softmax trick, the denominator is numerically stable while summing over all the exponentials (with the max value subtracted in the exponent).

The denominator in your case is the same, and the numerator will be the same except for the one term which is absent. So it seems that the numerator is also stable?

Let’s consider the case where X = [-1000, -1000, 1000] and c = X.max() = 1000.

After subtracting the maximum, X - c = [-2000, -2000, 0] and exp(X - c) = [0.0, 0.0, 1].

Therefore, LSE = LogSumExp(X - c) = 0, and LSE - (X - c) = [2000, 2000, 0].

This leads to log(LSE - (X - c)) = [7.6, 7.6, -Inf].

I think the denominator you mentioned is the LSE sum, and once we subtract the dominant term from it to form the numerator, it goes to zero.

Hi Zhengwei!

As noted by Mert, the numerator is just the log-softmax denominator with the i-th term removed.

That is, 1 - softmax (X)_i is

(sum_j (exp (X_j)) - exp (X_i)) / sum_j (exp (X_j)) = sum_(j != i) (exp (X_j)) / sum_j (exp (X_j))

This is also discussed in the stats.stackexchange thread you linked to.
You can implement your numerically-stable log1m_softmax() “by hand”
along the lines discussed in that thread.
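
For instance (my own rough sketch, not lifted verbatim from that thread), you can apply the shift-by-max trick separately to each sum over j != i:

import torch

def log1m_softmax_by_hand(x):
    # sketch for a 1-d tensor of logits x
    n = x.size(0)
    off_diag = ~torch.eye(n, dtype=torch.bool)
    # row i holds the x_j with j != i
    others = x.unsqueeze(0).expand(n, n)[off_diag].reshape(n, n - 1)
    # log (sum_{j != i} exp (x_j)) via the usual shift-by-max trick
    c = others.max(dim=1).values
    log_numerator = c + (others - c.unsqueeze(1)).exp().sum(dim=1).log()
    # log (sum_j exp (x_j))
    log_denominator = x.logsumexp(dim=0)
    return log_numerator - log_denominator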

You can also use pytorch’s logsumexp() to compute log1m_softmax()
without, in effect, reimplementing the log-sum-exp trick.

With a little manipulation, you can zero out the i == j term in probability
space (i.e., in “exp” space) by replacing the term with -inf (or a very
large negative number) in log space (i.e., the space of your original X)
and then apply pytorch’s logsumexp() to both the numerator and
denominator of the above expression for 1 - softmax (X).

Consider:

>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> X = torch.tensor ([-1000, -1000, 1000.0])
>>>
>>> X.softmax (0)               # softmax underflows to zero
tensor([0., 0., 1.])
>>> X.softmax (0).log()         # log gives -inf
tensor([-inf, -inf, 0.])
>>> X.log_softmax (0)           # numerically-stable version uses log-sum-exp trick internally
tensor([-2000., -2000.,     0.])
>>>
>>> 1 - X.softmax (0)           # (1 - softmax) underflows to zero
tensor([1., 1., 0.])
>>> (1 - X.softmax (0)).log()   # log gives -inf
tensor([0., 0., -inf])
>>> # use pytorch's logsumexp() to implement numerically-stable version
>>> n = X.size (0)
>>> (X.unsqueeze (0).expand (n, n) + torch.diag (-torch.ones_like (X) / 0)).logsumexp (1) - X.logsumexp (0)
tensor([    0.0000,     0.0000, -1999.3069])
>>>
>>> Y = torch.randn (5)         # check against unstable version with random data (that doesn't underflow)
>>> Y
tensor([ 0.1915,  0.3306,  0.2306,  0.8936, -0.2044])
>>> (1 - Y.softmax (0)).log()
tensor([-0.1864, -0.2175, -0.1946, -0.4203, -0.1216])
>>> n = Y.size (0)
>>> (Y.unsqueeze (0).expand (n, n) + torch.diag (-torch.ones_like (Y) / 0)).logsumexp (1) - Y.logsumexp (0)
tensor([-0.1864, -0.2175, -0.1946, -0.4203, -0.1216])

Best.

K. Frank


Hi Frank!

Thanks a lot for your answer, code, and explanation.

Adding -inf to the j == i term (i.e., setting xi = -inf) removes that term perfectly in probability space (since exp(-inf) = 0), and combining this with logsumexp() gives a numerically stable log1m_softmax(). This is the solution to this question!

However, I cannot find this solution in any of the answers in the stackexchange thread; none of them introduce this -inf trick, and they lead to other stability problems. Could you please point out the answer where you found it?

Hi Zhengwei!

I didn’t mean to suggest that the stackexchange thread proposes using
-inf to “zero out” the extra term in log-probability space. I chose to go
this route in order to use pytorch’s logsumexp() and avoid implementing
a version of the log-sum-exp trick that works for the terms that show up
in the expression for log (1 - softmax (X)).

Best.

K. Frank

Sorry, there was a mistake in my example above; it should be:

Therefore, the sum of exp(X - c) is SE = 1, and SE - exp(X - c) = [1, 1, 0].

This leads to log(SE - exp(X - c)) = [0, 0, -inf], so although the denominator is stable with LSE, the numerator is unstable for the dominant term.
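
A quick check of the corrected numbers (sketch):

import torch

X = torch.tensor([-1000.0, -1000.0, 1000.0])
c = X.max()
e = (X - c).exp()        # [0., 0., 1.]
SE = e.sum()             # 1.
print((SE - e).log())    # [0., 0., -inf]: the numerator still hits -inf for the dominant term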