Sampling from Categorical distribution based on values

Hello peeps, I am trying to implement the REINFORCE algorithm for sequence-to-sequence modeling. For this, I need to get a baseline (greedy) distribution and a sampled distribution along with its probabilities.

When it comes to the sampled distribution, at each decoding step I am doing the following to sample a token and its log-probability:

import torch
from torch.distributions import Categorical

# token_1 holds the log-probabilities of the gathered tokens
token_1 = torch.gather(log_probs, 1, indexes)
multi_dist = Categorical(logits=token_1)  # pass log-probs as logits, not as probs
x_t = multi_dist.sample()
log_prob = multi_dist.log_prob(x_t)

However, my model has a large action space, and I would like to sample tokens only from a restricted subset of it. Is there any way I can either sample a token based on its probability/value or mask the tensor I am passing to Categorical?


Did you try setting the logits of the tokens you don’t want to sample to -inf? Categorical will then assign them zero probability.
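
A minimal sketch of what I mean (the shapes and the `allowed` mask are just assumptions here; adapt them to your decoder’s output):

import torch
from torch.distributions import Categorical

# assumed shape: [batch, vocab_size] per-token log-probabilities from the decoder
batch_size, vocab_size = 4, 10
log_probs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=1)

# hypothetical boolean mask marking the tokens you are willing to sample
allowed = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
allowed[:, :5] = True  # e.g. restrict sampling to the first 5 tokens

# set the logits of disallowed tokens to -inf so their probability becomes exactly 0
masked_logits = log_probs.masked_fill(~allowed, float('-inf'))

multi_dist = Categorical(logits=masked_logits)
x_t = multi_dist.sample()            # only allowed tokens can be drawn
log_prob = multi_dist.log_prob(x_t)  # log-prob renormalized over the allowed set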