Hi Zhengwei!
As noted by Mert:
That is, 1 - softmax (X)_i is

(sum_j exp (X_j) - exp (X_i)) / sum_j exp (X_j) = sum_(j != i) exp (X_j) / sum_j exp (X_j)
This is also discussed in the stats.stackexchange thread you linked to.
You can implement your numerically-stable log1m_softmax() “by hand” along the lines discussed in that thread.
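For concreteness, here is a minimal sketch of what such a “by hand” version might look like. It assumes a 1-d X, and the helper name log1m_softmax_by_hand is just a placeholder, not anything from the thread:

import torch

def log1m_softmax_by_hand (X):
    # sketch: apply the max-subtraction (log-sum-exp) trick separately to the
    # numerator, sum_(j != i) exp (X_j), and the denominator, sum_j exp (X_j)
    n = X.size (0)
    # drop the i == j term from row i by setting it to -inf
    Xmat = X.unsqueeze (0).expand (n, n).masked_fill (torch.eye (n, dtype = torch.bool), float ('-inf'))
    m_num = Xmat.max (dim = 1).values                                          # per-row max over j != i
    log_num = (Xmat - m_num.unsqueeze (1)).exp().sum (dim = 1).log() + m_num   # log (sum_(j != i) exp (X_j))
    m_den = X.max()
    log_den = (X - m_den).exp().sum().log() + m_den                            # log (sum_j exp (X_j))
    return log_num - log_den                                                   # log (1 - softmax (X))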
You can also use pytorch’s logsumexp() to compute log1m_softmax() without, in effect, reimplementing the log-sum-exp trick.
With a little manipulation, you can zero out the i == j term in probability space (i.e., in “exp” space) by replacing the term with -inf (or a very large negative number) in log space (i.e., the space of your original X) and then apply pytorch’s logsumexp() to both the numerator and denominator of the above expression for 1 - softmax (X).
Consider:
>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> X = torch.tensor ([-1000, -1000, 1000.0])
>>>
>>> X.softmax (0) # softmax underflows to zero
tensor([0., 0., 1.])
>>> X.softmax (0).log() # log gives -inf
tensor([-inf, -inf, 0.])
>>> X.log_softmax (0) # numerically-stable version uses log-sum-exp trick internally
tensor([-2000., -2000., 0.])
>>>
>>> 1 - X.softmax (0) # (1 - softmax) underflows to zero
tensor([1., 1., 0.])
>>> (1 - X.softmax (0)).log() # log gives -inf
tensor([0., 0., -inf])
>>> # use pytorch's logsumexp() to implement numerically-stable version
>>> n = X.size (0)
>>> (X.unsqueeze (0).expand (n, n) + torch.diag (-torch.ones_like (X) / 0)).logsumexp (1) - X.logsumexp (0)
tensor([ 0.0000, 0.0000, -1999.3069])
>>>
>>> Y = torch.randn (5) # check against unstable version with random data (that doesn't underflow)
>>> Y
tensor([ 0.1915, 0.3306, 0.2306, 0.8936, -0.2044])
>>> (1 - Y.softmax (0)).log()
tensor([-0.1864, -0.2175, -0.1946, -0.4203, -0.1216])
>>> n = Y.size (0)
>>> (Y.unsqueeze (0).expand (n, n) + torch.diag (-torch.ones_like (Y) / 0)).logsumexp (1) - Y.logsumexp (0)
tensor([-0.1864, -0.2175, -0.1946, -0.4203, -0.1216])
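If it helps, the logsumexp expression from the session above could be wrapped in a small helper (again just a sketch for the 1-d case; the name log1m_softmax is a placeholder):

import torch

def log1m_softmax (X):
    # wraps the logsumexp expression demonstrated in the session above (1-d X)
    n = X.size (0)
    neg_inf_diag = torch.diag (torch.full_like (X, float ('-inf')))   # -inf on the diagonal zeroes out the i == j term
    return (X.unsqueeze (0).expand (n, n) + neg_inf_diag).logsumexp (1) - X.logsumexp (0)

Calling log1m_softmax (X) or log1m_softmax (Y) with the tensors above should reproduce the stable results shown in the session.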
Best.
K. Frank