Function `gumbel_softmax`'s documentation is misleading

In the documentation of `gumbel_softmax`, the first parameter `logits` is described as:

logits: `[..., num_features]` unnormalized log probabilities

It confused me a lot why the logits could be unnormalized. The torch function uses the equation

softmax((log p_i - log (-log e_i)) / t)

where -log(-log e_i) is the Gumbel noise (with e_i sampled uniformly from (0, 1)), t is the temperature, and p_i is the probability. If you use unnormalized logits o_i in place of log p_i, the equation may not hold unless you have

\sum_i exp(o_i) = 1

which is not guaranteed by an arbitrary network.
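
For example, a minimal sketch (the `Linear` layer here is just a stand-in for an arbitrary network):

import torch

net = torch.nn.Linear(10, 5)        # stand-in for an arbitrary network
o = net(torch.randn(10))            # unnormalized outputs o_i
print(o.exp().sum())                # in general this is not 1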

This thread also mentions this issue:

Hi Weikang!

I do not have an opinion about whether or not pytorch’s gumbel_softmax()
is correct.

However, as a practical matter, if you are concerned about this issue, you
may pass the output of your “arbitrary network” through log_softmax()
before passing it to gumbel_softmax().
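
For example, a minimal sketch (the Linear layer and input below are placeholders
for your own network and data, and the tau / hard settings are just examples):

>>> import torch
>>> import torch.nn.functional as F
>>> net = torch.nn.Linear (10, 5)              # placeholder for your "arbitrary network"
>>> x = torch.randn (3, 10)                    # placeholder input batch
>>> log_probs = net (x).log_softmax (-1)       # normalized log-probabilities
>>> sample = F.gumbel_softmax (log_probs, tau = 1.0, hard = True)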

One way of looking at this is that the only thing that log_softmax() does
is convert unnormalized log-probabilities to normalized log-probabilities.

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.1
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> unnormalized_log_probabilities = torch.randn (5)
>>> unnormalized_log_probabilities
tensor([-1.2075,  0.5493, -0.3856,  0.6910, -0.7424])
>>>
>>> probabilities = unnormalized_log_probabilities.softmax (0)
>>> probabilities
tensor([0.0577, 0.3342, 0.1312, 0.3851, 0.0918])
>>> probabilities.sum()
tensor(1.)
>>>
>>> normalized_log_probabilities = unnormalized_log_probabilities.log_softmax (0)
>>> normalized_log_probabilities
tensor([-2.8528, -1.0960, -2.0309, -0.9544, -2.3877])
>>> normalized_log_probabilities.exp()
tensor([0.0577, 0.3342, 0.1312, 0.3851, 0.0918])
>>> normalized_log_probabilities.exp().sum()
tensor(1.)
>>>
>>> torch.equal (normalized_log_probabilities.log_softmax (0), normalized_log_probabilities)
True

(Closely related is that if you start with normalized log-probabilities, then
the inverse of softmax() is log().)
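
Continuing the session above, a quick check of that (up to floating-point round-off):

>>> torch.allclose (probabilities.log(), normalized_log_probabilities)
True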

As an aside, log-probabilities are not logits (and pytorch should not use
logits as the name of the log-probabilities argument to gumbel_softmax()).
I’ve mixed up the terminology in some of my posts, but it’s worth trying to get
it right. Logits are log-odds-ratios – log (p / (1 - p)) (and are converted
to probabilities with the logistic sigmoid() function). Logits play the same
role with BCEWithLogitsLoss that (unnormalized) log-probabilities play
with CrossEntropyLoss.
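
As a quick sketch of that distinction (the probabilities here are arbitrary examples):

>>> p = torch.tensor ([0.2, 0.5, 0.9])
>>> logit = torch.log (p / (1 - p))            # log-odds-ratio
>>> torch.sigmoid (logit)                      # sigmoid() converts logits back to probabilities
tensor([0.2000, 0.5000, 0.9000])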

Best.

K. Frank