gumbel_softmax: should the logits be normalized?

Link to the PyTorch implementation

Both in the code and in the docs, the logits argument of the function is described as “unnormalized log probabilities”. If this means the raw scores before any softmax layer, I have a hard time understanding why this should work at all. In both the RelaxedOneHotCategorical distribution implementation and the original paper by Jang et al., the logits are clearly normalized. However, in Jang’s notebook implementation, which is where PyTorch got its Gumbel-softmax from, the input to that function comes straight out of a linear layer, meaning it is not normalized. Does anyone have any insight into this?


Hi~ I have the same question.
Did you fix it?

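Okay, I figured it out.
It does not matter. A softmax is applied after the Gumbel noise is added to the logits, and softmax is unchanged when the same constant is added to every logit. Normalized logits differ from the raw scores only by such a constant (the log-sum-exp of the scores, which the temperature merely rescales), so normalized and unnormalized logits result in the same probabilities.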
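
To check this numerically, here is a minimal sketch in plain Python (standard library only, so it runs without PyTorch installed). The helper names `softmax`, `log_softmax`, `sample_gumbel` and `gumbel_softmax` are my own for illustration and are not PyTorch's functions; the example scores and the temperature of 0.5 are arbitrary. It applies the same Gumbel noise to raw scores and to their log-softmax and compares the results.

```python
import math
import random


def softmax(xs):
    # Subtract the max for numerical stability; this does not change the result.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(xs):
    # Normalized log probabilities: x_i - logsumexp(x).
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def sample_gumbel(n, rng):
    # Gumbel(0, 1) noise via -log(-log(U)), with U ~ Uniform(0, 1).
    noise = []
    for _ in range(n):
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        noise.append(-math.log(-math.log(u)))
    return noise


def gumbel_softmax(logits, noise, tau=1.0):
    # Soft sample: softmax((logits + noise) / tau).
    return softmax([(l + g) / tau for l, g in zip(logits, noise)])


rng = random.Random(0)
raw_logits = [2.0, -1.0, 0.5, 3.0]           # e.g. output of a linear layer
normalized_logits = log_softmax(raw_logits)  # same scores, shifted by a constant
noise = sample_gumbel(len(raw_logits), rng)  # shared noise for a fair comparison

y_raw = gumbel_softmax(raw_logits, noise, tau=0.5)
y_norm = gumbel_softmax(normalized_logits, noise, tau=0.5)

same = all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for a, b in zip(y_raw, y_norm))
print(same)
```

The two outputs agree up to floating-point rounding because the log-sum-exp constant, divided by the temperature, is added to every entry and cancels inside the softmax.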