Input for torch.nn.functional.gumbel_softmax

Say I have a tensor named attn_weights of size [1, a], whose entries are the attention weights between the given query and a keys. I want to select the largest one using torch.nn.functional.gumbel_softmax.

The docs describe the parameter as logits – […, num_features] unnormalized log probabilities. Should I take the log of attn_weights before passing it into gumbel_softmax? Also, Wikipedia defines logit = log(p / (1 - p)), which is different from a plain logarithm. Which one should I pass to the function?

Further, are there any guidelines for choosing tau in gumbel_softmax?

Based on the example code snippet in the docs, “unnormalized log probabilities” refers to raw logits (pre-softmax scores), not the log(p / (1 - p)) logit from Wikipedia — so no extra log or softmax should be applied to raw scores, if I understand it correctly.
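If attn_weights are already normalized probabilities (e.g. the output of a softmax), taking their log turns them back into valid logits, since softmax is invariant to the constant log-normalizer. A minimal sketch of both points — the log-of-weights input and one possible tau schedule (the annealing idea comes from the original Gumbel-Softmax paper; the exact constants here are just illustrative assumptions):

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical attention weights for one query over 4 keys
# (nonnegative and normalized, e.g. the output of a softmax).
attn_weights = torch.tensor([[0.1, 0.5, 0.3, 0.1]])

# gumbel_softmax expects logits, i.e. unnormalized log probabilities,
# so take the log of the positive weights (epsilon guards against log(0)).
logits = torch.log(attn_weights + 1e-10)

# hard=True returns a one-hot sample in the forward pass while the backward
# pass uses the soft relaxation (straight-through estimator).
sample = F.gumbel_softmax(logits, tau=1.0, hard=True)
assert sample.shape == (1, 4) and float(sample.sum()) == 1.0

# One common heuristic for tau, in the spirit of the original paper:
# start soft (large tau) and anneal toward a small floor during training,
# so samples become increasingly discrete. Constants are illustrative.
def anneal_tau(step, tau_init=1.0, tau_min=0.5, rate=1e-4):
    return max(tau_min, tau_init * math.exp(-rate * step))
```

With hard=True the forward output is exactly one-hot, so `sample.argmax(dim=-1)` gives the sampled key index while gradients still flow through the soft relaxation.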

OK, I got it. Thank you very much. I’ll check out the original paper.