Say I have a tensor `attn_weights` of shape `[1, a]`, whose entries are the attention weights between a given query and `a` keys. I want to select the largest one using `torch.nn.functional.gumbel_softmax`.
The docs describe the parameter as "logits – [..., num_features] unnormalized log probabilities". Should I take the log of `attn_weights` before passing it to `gumbel_softmax`? On the other hand, Wikipedia defines the logit as logit(p) = log(p / (1 − p)), which is different from a plain logarithm. Which of the two should I pass to the function?
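For context, here is a minimal sketch of the option I'm considering (the tensor values and the number of keys are made up; I simply take the plain log of the normalized weights before sampling):

```python
import torch
import torch.nn.functional as F

a = 5  # hypothetical number of keys
# Toy attention weights: positive and summing to 1, shape [1, a]
attn_weights = torch.softmax(torch.randn(1, a), dim=-1)

# Candidate interpretation: log turns probabilities into (unnormalized) log probabilities
logits = torch.log(attn_weights)

# hard=True returns a one-hot sample (straight-through in the backward pass)
sample = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(sample)  # one-hot vector of shape [1, a]
```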
Further, are there any guidelines for choosing `tau` in `gumbel_softmax`?
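To make the `tau` question concrete, here is a toy experiment (the probabilities are made up) showing what I understand the temperature to control: with identical Gumbel noise, a small `tau` gives a near-one-hot output while a large `tau` flattens it toward uniform.

```python
import torch
import torch.nn.functional as F

# Made-up probabilities for three keys, turned into log probabilities
logits = torch.log(torch.tensor([[0.1, 0.2, 0.7]]))

for tau in (0.1, 1.0, 10.0):
    torch.manual_seed(0)  # reuse the same Gumbel noise so only tau changes
    y = F.gumbel_softmax(logits, tau=tau)
    print(f"tau={tau}: {y}")  # small tau -> spiky, large tau -> flat
```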