What confused me a lot is why the logits could be unnormalized. The torch function uses the equation

softmax((log p_i - log(-log e_i)) / t),

where -log(-log e_i) is the gumbel noise, t is the temperature, and p_i is the probability. If you want to use unnormalized logits o_i in place of log p_i, the equation may not hold unless you have o_i = log p_i + c for some constant c that is the same for every i.
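To make the concern concrete, here is the substitution written out by hand (the tensor values and variable names are just a toy example, not pytorch's internal code):

```python
import torch

torch.manual_seed(0)
o = torch.randn(5)                             # unnormalized logits o_i
log_p = torch.log_softmax(o, dim=-1)           # normalized log p_i
g = -torch.log(-torch.log(torch.rand(5)))      # gumbel noise, shared by both
t = 0.5                                        # temperature

y_logp = torch.softmax((log_p + g) / t, dim=-1)   # the formula as written
y_o = torch.softmax((o + g) / t, dim=-1)          # with o_i in place of log p_i
print(torch.allclose(y_logp, y_o))                # True in this run: o_i = log p_i + logsumexp(o),
                                                  # and softmax() ignores a constant shift
```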
I do not have an opinion about whether or not pytorch’s gumbel_softmax()
is correct.
However, as a practical matter, if you are concerned about this issue, you
may pass the output of your “arbitrary network” through log_softmax()
before passing it to gumbel_softmax().
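For example (a minimal sketch, where the Linear layer just stands in for your "arbitrary network"):

```python
import torch
import torch.nn.functional as F

net = torch.nn.Linear(16, 10)          # stand-in for the arbitrary network
x = torch.randn(4, 16)

raw = net(x)                            # unnormalized output
log_p = F.log_softmax(raw, dim=-1)      # normalized log-probabilities
y = F.gumbel_softmax(log_p, tau=1.0, hard=True)   # one-hot (straight-through) samples
print(y.shape)                          # torch.Size([4, 10])
```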
One way of looking at this is that the only thing that log_softmax() does
is convert unnormalized log-probabilities to normalized log-probabilities.
(Closely related is that if you start with normalized log-probabilities, then
the inverse of softmax() is log().)
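A quick numerical check of both statements (the tensor here is arbitrary):

```python
import torch

u = torch.randn(6)                       # unnormalized log-probabilities
log_p = torch.log_softmax(u, dim=-1)     # normalized log-probabilities
print(log_p.exp().sum())                 # ~1.0 -- they now exponentiate to a distribution

p = torch.softmax(log_p, dim=-1)
print(torch.allclose(p.log(), log_p))    # True: starting from normalized
                                         # log-probabilities, log() undoes softmax()
```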
As an aside, log-probabilities are not logits (and pytorch should not use logits as the name of the log-probabilities argument to gumbel_softmax()).
I’ve mixed up the terminology in some of my posts, but it’s worth trying to get
it right. Logits are log-odds-ratios, log (p / (1 - p)), and are converted
to probabilities with the logistic sigmoid() function. Logits play the same
role with BCEWithLogitsLoss that (unnormalized) log-probabilities play
with CrossEntropyLoss.
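To illustrate with a toy example (shapes and values are made up):

```python
import torch

# binary case: BCEWithLogitsLoss takes logits, i.e. sigmoid() recovers probabilities
logit = torch.randn(8)
target_b = torch.rand(8)
with_logits = torch.nn.BCEWithLogitsLoss()(logit, target_b)
with_probs = torch.nn.BCELoss()(torch.sigmoid(logit), target_b)
print(torch.allclose(with_logits, with_probs))    # True

# multi-class case: CrossEntropyLoss takes unnormalized log-probabilities,
# i.e. it applies log_softmax() internally and then NLLLoss
scores = torch.randn(8, 5)
target_c = torch.randint(5, (8,))
ce = torch.nn.CrossEntropyLoss()(scores, target_c)
nll = torch.nn.NLLLoss()(torch.log_softmax(scores, dim=-1), target_c)
print(torch.allclose(ce, nll))                    # True
```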