Hello everyone,

I am doing reinforcement learning with a policy gradient algorithm.

However, some actions are impossible in a given state, so I mask the logits of the impossible actions before constructing the Categorical distribution.

I apply this mask with a `torch.where`.

I have profiled the code, and this masking is one of the biggest bottlenecks.

Do you have any suggestions for making it faster?

Best regards

```
import torch
from torch.distributions import Categorical
class CategoricalMasked(Categorical):
    LOGIT_INF_VALUE = -1e10

    def __init__(self, mask, logits):
        # dtype of mask: torch.float32
        # example size of logits and mask: (4096, 5)
        self.mask = torch.gt(mask, 0.5)
        # Replace the logits of masked-out actions with a very large negative value
        logits = torch.where(
            self.mask,
            logits,
            torch.ones_like(logits) * torch.tensor(CategoricalMasked.LOGIT_INF_VALUE, device=logits.device),
        )
        super(CategoricalMasked, self).__init__(None, logits, None)

    def entropy(self):
        # Zero out the contribution of masked-out actions to the entropy
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(
            self.mask,
            p_log_p,
            torch.tensor(0., device=self.logits.device),
        )
        return -p_log_p.sum(-1)
```
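
For reference, here is the masking step isolated from the class, plus a `masked_fill` variant I am considering, which skips allocating the fill tensor (results are identical; I have not yet profiled whether it is actually faster):

```python
import torch

LOGIT_INF_VALUE = -1e10

# Same shapes as in the snippet above
logits = torch.randn(4096, 5)
mask = (torch.rand(4096, 5) > 0.3).float()  # float32 mask, as in my code

bool_mask = torch.gt(mask, 0.5)

# Current approach: torch.where with a freshly allocated fill tensor
masked_where = torch.where(
    bool_mask,
    logits,
    torch.ones_like(logits) * torch.tensor(LOGIT_INF_VALUE, device=logits.device),
)

# Possible alternative: masked_fill writes the constant directly
masked_fill = logits.masked_fill(~bool_mask, LOGIT_INF_VALUE)

assert torch.equal(masked_where, masked_fill)
```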