Hello everyone,
I am doing reinforcement learning with a policy gradient algorithm.
However, some actions are impossible in my environment, so I mask the logits of those impossible actions before building the Categorical distribution.
I apply this mask with a torch.where.
I have profiled the code, and this masking turns out to be one of the biggest bottlenecks.
Do you have any suggestions for making it faster?
Best regards
import torch
from torch.distributions import Categorical


class CategoricalMasked(Categorical):
    LOGIT_INF_VALUE = -1e10

    def __init__(self, mask, logits):
        # mask dtype: torch.float32
        # example size of logits and mask: (4096, 5)
        self.mask = torch.gt(mask, 0.5)
        # Replace the logits of impossible actions with a very large
        # negative value so their probability is effectively zero
        logits = torch.where(
            self.mask,
            logits,
            torch.ones_like(logits) * torch.tensor(CategoricalMasked.LOGIT_INF_VALUE, device=logits.device),
        )
        super().__init__(logits=logits)

    def entropy(self):
        p_log_p = self.logits * self.probs
        # Drop the contribution of masked actions before summing
        p_log_p = torch.where(self.mask, p_log_p, torch.tensor(0., device=self.logits.device))
        return -p_log_p.sum(-1)
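
For what it is worth, here is a sketch of a variant I am considering, using masked_fill so that a Python scalar is broadcast instead of allocating a ones_like tensor on every call. I have not benchmarked it, so I am not sure it actually helps:

import torch
from torch.distributions import Categorical


class CategoricalMaskedFill(Categorical):
    # Untested variant for comparison, same semantics as above
    LOGIT_INF_VALUE = -1e10

    def __init__(self, mask, logits):
        self.mask = mask > 0.5
        # masked_fill writes a scalar where ~self.mask is True, so no
        # (4096, 5) temporary tensor is allocated per call
        logits = logits.masked_fill(~self.mask, CategoricalMaskedFill.LOGIT_INF_VALUE)
        super().__init__(logits=logits)

    def entropy(self):
        p_log_p = self.logits * self.probs
        # Same scalar-fill trick to zero out the masked entries
        p_log_p = p_log_p.masked_fill(~self.mask, 0.0)
        return -p_log_p.sum(-1)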