from torch.distributions import Categorical

logits = self.network(inputs)
# This does not work because of NaNs: for allowed actions (1 - possible_actions) is 0,
# and 0 * inf is NaN. Is there a way to do it without slicing?
# logits -= (1 - possible_actions) * float('inf')
dist = Categorical(logits=logits)
# This is clean, but it requires recalculating the parameters; the logits don't update either.
# dist.probs *= possible_actions
action = dist.sample()
action_log_probs = dist.log_prob(action)
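A minimal sketch, assuming PyTorch's `torch.distributions.Categorical`. Instead of subtracting `inf`, `masked_fill` writes a fill value only where the mask is 0, so there is no slicing and no `0 * inf` arithmetic. It uses `torch.finfo(dtype).min` rather than `-inf` so every logit stays finite. After the softmax inside `Categorical`, those entries still underflow to probability 0. The helper name `masked_logits` and the random stand-in logits are illustrative, not part of the original code.

```python
import torch
from torch.distributions import Categorical


def masked_logits(logits: torch.Tensor, possible_actions: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: replace logits of impossible actions (mask == 0).

    The most negative finite float is used instead of -inf, so no inf
    arithmetic happens; softmax maps these entries to probability 0.
    """
    fill_value = torch.finfo(logits.dtype).min
    return logits.masked_fill(possible_actions == 0, fill_value)


logits = torch.randn(2, 6)  # stand-in for self.network(inputs), batch of 2
possible_actions = torch.tensor([[0, 0, 1, 0, 1, 1],
                                 [1, 0, 0, 0, 0, 1]])

dist = Categorical(logits=masked_logits(logits, possible_actions))
action = dist.sample()                     # only allowed actions can appear
action_log_probs = dist.log_prob(action)   # finite for every sampled action
```

The mask works per row, so the same call handles batches where each state has a different set of legal actions.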
I want to prevent actions where possible_actions == 0 from being sampled.
possible_actions is a binary mask like this: [0, 0, 1, 0, 1, 1,…]
In this case, the first, second, and fourth actions must never be sampled, while the third, fifth, and sixth can be.
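If you prefer the probs route, one option (assuming PyTorch) is to avoid modifying `dist.probs` in place and instead build a fresh `Categorical` from the masked probabilities. `Categorical(probs=...)` divides its input by the sum, so the remaining mass is renormalized and `log_prob` is computed from the new parameters. The sketch uses the example mask `[0, 0, 1, 0, 1, 1]` and random stand-in logits.

```python
import torch
from torch.distributions import Categorical

logits = torch.randn(6)  # stand-in for self.network(inputs)
possible_actions = torch.tensor([0, 0, 1, 0, 1, 1])

dist = Categorical(logits=logits)
# Zero out impossible actions and construct a *new* distribution;
# Categorical(probs=...) renormalizes, so the allowed actions sum to 1.
masked_dist = Categorical(probs=dist.probs * possible_actions)

action = masked_dist.sample()
action_log_probs = masked_dist.log_prob(action)
```

One caveat: PyTorch derives the logits of a probs-constructed `Categorical` from clamped probabilities, so masked entries get a logit of roughly `log(eps)` rather than `-inf`. They still have probability 0 and are never sampled, but masking the logits directly avoids the extra softmax and log round trip.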
I'm mostly interested in the best™ way to do it, not just one that works: I could simply set the logits to -inf, but that ends up being messy.
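One way to keep the masking from spreading through the training code is to do it once, inside a small subclass. This is a sketch assuming PyTorch. `MaskedCategorical` is a hypothetical name, not part of `torch.distributions`. It masks the logits in `__init__` and inherits `sample`, `log_prob`, and `entropy` unchanged from `Categorical`.

```python
import torch
from torch.distributions import Categorical


class MaskedCategorical(Categorical):
    """Hypothetical helper: a Categorical that never samples masked actions.

    Entries where mask == 0 get the most negative finite logit, so their
    probability is 0 and entropy() stays finite (0 * finite == 0).
    """

    def __init__(self, logits: torch.Tensor, mask: torch.Tensor):
        self.mask = mask.bool()
        fill_value = torch.finfo(logits.dtype).min
        super().__init__(logits=logits.masked_fill(~self.mask, fill_value))


logits = torch.randn(6)  # stand-in for self.network(inputs)
possible_actions = torch.tensor([0, 0, 1, 0, 1, 1])

dist = MaskedCategorical(logits, possible_actions)
action = dist.sample()
action_log_probs = dist.log_prob(action)
entropy = dist.entropy()
```

The call site then looks almost the same as the unmasked version: only the constructor takes the extra mask argument.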