Link to the PyTorch implementation:
Both in the code and in the docs, the logits argument of the function is annotated as "unnormalized log probabilities". If this is intended to mean the raw scores before any softmax layer, then I have a hard time understanding why this should work at all. Both in the RelaxedOneHotCategorical distribution implementation and in Jang's original paper, the logits are clearly normalized. However, in Jang's notebook implementation, which is where PyTorch got its gumbel_softmax from, the input to that function comes straight out of a linear layer, meaning it is not normalized. A rough sketch of the two usages I mean is below. Does anyone have any insight into this?
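To make the question concrete, here is a minimal sketch of the two call patterns I am contrasting (the layer sizes and variable names are just mine for illustration, not from any of the linked code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

linear = torch.nn.Linear(16, 10)
x = torch.randn(4, 16)

# Straight out of a linear layer, i.e. unnormalized scores,
# as in Jang's notebook and (apparently) the PyTorch docstring.
raw_scores = linear(x)
sample_raw = F.gumbel_softmax(raw_scores, tau=1.0, hard=False)

# Normalized log probabilities, matching what RelaxedOneHotCategorical
# and the paper seem to expect.
log_probs = F.log_softmax(raw_scores, dim=-1)
sample_norm = F.gumbel_softmax(log_probs, tau=1.0, hard=False)
```

Both calls run fine, so my question is whether the first usage is actually correct, or whether one is supposed to pass log-softmaxed values.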