How to sample k times with Gumbel-Softmax


I am trying to sample k elements from a categorical distribution in a differentiable way. I noticed that F.gumbel_softmax(logit, tau=1, hard=True) returns a one-hot tensor, but how can I sample k times using the Gumbel-Softmax, similar to the topk function in PyTorch?

I'm not sure, but maybe you can upsample logit k times with nearest-neighbor interpolation (which just repeats the values) and then pass the result to gumbel_softmax.
That way you should not have any gradient issues.
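
A minimal sketch of that idea, under some assumptions: `logit` has shape `(batch, num_classes)`, and the copying is done with `Tensor.repeat` instead of an actual nearest-neighbor `interpolate` call, since both only duplicate the logits. The helper name `sample_k_gumbel_softmax` is made up for illustration and is not part of PyTorch.

```python
import torch
import torch.nn.functional as F


def sample_k_gumbel_softmax(logit, k, tau=1.0):
    """Draw k straight-through one-hot samples per row of `logit`.

    Hypothetical helper (not a PyTorch API). Assumes `logit` has shape
    (batch, num_classes); returns a tensor of shape (batch, k, num_classes).
    """
    # Copy the logits k times along a new dimension: (batch, k, num_classes).
    repeated = logit.unsqueeze(1).repeat(1, k, 1)
    # gumbel_softmax adds independent Gumbel noise to every element, so each
    # of the k copies yields its own sample. hard=True returns one-hot
    # vectors in the forward pass and uses the soft sample for the gradient.
    return F.gumbel_softmax(repeated, tau=tau, hard=True, dim=-1)


torch.manual_seed(0)
logits = torch.randn(2, 5, requires_grad=True)
samples = sample_k_gumbel_softmax(logits, k=3)
print(samples.shape)  # torch.Size([2, 3, 5])
```

Note that the k samples are drawn independently, so the same class can be picked more than once. This differs from topk, which returns k distinct indices.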