Cleanly update a categorical distribution

logits = self.network(inputs)

# this does not work: for every allowed action it computes 0 * inf, which is nan. Is there a way to do it without slicing?
# logits -= (1 - possible_actions) * float('inf')

dist = Categorical(...)(logits)

# this is clean, but the distribution's parameters would have to be recalculated; the logits don't update either
# dist.probs *= possible_actions

action = dist.sample()
action_log_probs = dist.log_probs(action)
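A minimal sketch of masking the logits without slicing, assuming `torch.distributions.Categorical` stands in for the `Categorical(...)` wrapper above (so the method is `log_prob`; a custom wrapper may name it `log_probs`). The commented-out line yields NaNs because for every allowed action it evaluates `0 * inf`. `Tensor.masked_fill` instead writes the fill value only where the mask is true, so that product is never formed. Filling with the most negative finite float rather than `-inf` is a defensive choice of this sketch: forbidden probabilities still underflow to exactly zero, and any downstream code that multiplies logits by probabilities never sees `-inf * 0`. The helper name `mask_logits` and the random stand-in logits are hypothetical.

```python
import torch
from torch.distributions import Categorical


def mask_logits(logits: torch.Tensor, possible_actions: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: push logits of forbidden actions (possible_actions == 0)
    to the most negative finite value.

    masked_fill only writes where the mask is True, so allowed logits are left
    untouched and no 0 * inf -> nan product is ever formed.
    """
    forbidden = possible_actions == 0
    very_negative = torch.finfo(logits.dtype).min
    return logits.masked_fill(forbidden, very_negative)


# Hypothetical stand-ins for self.network(inputs) and the action mask.
logits = torch.randn(6, requires_grad=True)
possible_actions = torch.tensor([0, 0, 1, 0, 1, 1])

# Assumes torch.distributions.Categorical in place of the original wrapper.
dist = Categorical(logits=mask_logits(logits, possible_actions))
action = dist.sample()
action_log_probs = dist.log_prob(action)
```

The same call works on batched logits of shape `(batch, num_actions)` when `possible_actions` has a matching (or broadcastable) shape. The sketch assumes every row allows at least one action.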

I want to prevent actions where possible_actions == 0 from being sampled.

possible_actions is a vector like this: [0, 0, 1, 0, 1, 1,…]
In this case the first, second and fourth actions must not be sampled, but the third, fifth and sixth can be.

I’m mostly interested in the best™ way to do it, not just one that works: I could simply set the logits to -inf, but that ends up being messy.
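If you prefer to mask probabilities, as in the `dist.probs *= possible_actions` attempt, one option (again a sketch assuming `torch.distributions.Categorical`) is to build a fresh distribution from renormalized probabilities instead of mutating `dist.probs` in place. Multiplying the cached `probs` tensor leaves `logits` stale and the row no longer sums to 1, whereas the constructor derives both parameters consistently. `masked_probs` is a hypothetical helper name, and the random logits stand in for the network output.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def masked_probs(logits: torch.Tensor, possible_actions: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: softmax, zero out forbidden actions, renormalize to sum to 1."""
    probs = F.softmax(logits, dim=-1) * possible_actions.to(logits.dtype)
    return probs / probs.sum(dim=-1, keepdim=True)


# Hypothetical stand-ins for the network output and the action mask.
logits = torch.randn(6)
possible_actions = torch.tensor([0, 0, 1, 0, 1, 1])

# Build a new distribution rather than editing dist.probs in place, so that
# probs and logits stay consistent with each other.
dist = Categorical(probs=masked_probs(logits, possible_actions))
action = dist.sample()
action_log_probs = dist.log_prob(action)
```

Both routes give forbidden actions probability exactly zero; masking the logits skips the extra renormalization step and stays in log space.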