Hi Zhengwei!

As noted by Mert:

That is, `1 - softmax (X)_i`

is

```
(sum_j (exp (X_j)) - exp (X_i)) / sum_j (exp (X_j)) = sum_(j != i) (exp (X_j)) / sum_j (exp (X_j))
```

This is also discussed in the stats.stackexchange thread you linked to.
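As a quick sanity check, the identity above can be verified numerically on well-scaled data (the values here are arbitrary and chosen so that nothing underflows):

```python
import torch

# check that 1 - softmax (X)_i == (sum_j exp (X_j) - exp (X_i)) / sum_j exp (X_j)
X = torch.tensor([0.5, -1.2, 2.0])
lhs = 1 - X.softmax(0)
rhs = (X.exp().sum() - X.exp()) / X.exp().sum()
print(torch.allclose(lhs, rhs))  # True
```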

You can implement your numerically-stable `log1m_softmax()` “by hand” along the lines discussed in that thread.

You can also use pytorch’s `logsumexp()` to compute `log1m_softmax()` without, in effect, reimplementing the log-sum-exp trick.

With a little manipulation, you can zero out the `i == j` term in probability space (i.e., in “exp” space) by replacing that term with `-inf` (or a very large negative number) in log space (i.e., the space of your original `X`), and then apply pytorch’s `logsumexp()` to both the numerator and denominator of the above expression for `1 - softmax (X)`.

Consider:

```
>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> X = torch.tensor ([-1000, -1000, 1000.0])
>>>
>>> X.softmax (0) # softmax underflows to zero
tensor([0., 0., 1.])
>>> X.softmax (0).log() # log gives -inf
tensor([-inf, -inf, 0.])
>>> X.log_softmax (0) # numerically-stable version uses log-sum-exp trick internally
tensor([-2000., -2000., 0.])
>>>
>>> 1 - X.softmax (0) # (1 - softmax) underflows to zero
tensor([1., 1., 0.])
>>> (1 - X.softmax (0)).log() # log gives -inf
tensor([0., 0., -inf])
>>> # use pytorch's logsumexp() to implement numerically-stable version
>>> n = X.size (0)
>>> (X.unsqueeze (0).expand (n, n) + torch.diag (-torch.ones_like (X) / 0)).logsumexp (1) - X.logsumexp (0)
tensor([ 0.0000, 0.0000, -1999.3069])
>>>
>>> Y = torch.randn (5) # check against unstable version with random data (that doesn't underflow)
>>> Y
tensor([ 0.1915, 0.3306, 0.2306, 0.8936, -0.2044])
>>> (1 - Y.softmax (0)).log()
tensor([-0.1864, -0.2175, -0.1946, -0.4203, -0.1216])
>>> n = Y.size (0)
>>> (Y.unsqueeze (0).expand (n, n) + torch.diag (-torch.ones_like (Y) / 0)).logsumexp (1) - Y.logsumexp (0)
tensor([-0.1864, -0.2175, -0.1946, -0.4203, -0.1216])
```
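If you prefer, the one-liner from the transcript can be packaged into a small function. This is just a sketch for the 1-d case; the name `log1m_softmax` is for illustration only (it is not a pytorch builtin), and it uses an explicit `float('-inf')` diagonal rather than the `-1 / 0` idiom above:

```python
import torch

def log1m_softmax(x):
    """Numerically stable log (1 - softmax (x)) for a 1-d tensor x."""
    n = x.size(0)
    # build an n x n matrix whose i-th row is x with x_i replaced by -inf,
    # so exp (x_i) drops out of the i-th row's sum under logsumexp()
    num = x.unsqueeze(0).expand(n, n) + torch.diag(torch.full_like(x, float('-inf')))
    # log (sum_(j != i) exp (x_j)) - log (sum_j exp (x_j))
    return num.logsumexp(1) - x.logsumexp(0)

X = torch.tensor([-1000.0, -1000.0, 1000.0])
print(log1m_softmax(X))  # finite everywhere, roughly [0., 0., -1999.3069]
```

For well-scaled inputs this agrees with the naive `(1 - x.softmax (0)).log()`, while for extreme inputs such as `X` above it stays finite where the naive version returns `-inf`.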

Best.

K. Frank