torch.where is too slow

Hello everyone,
I am doing reinforcement learning with a policy gradient algorithm.
However, some of my actions are impossible, so I modify the logits of those impossible actions before building the Categorical distribution.

To apply this mask I use torch.where.
I have profiled the code, and this is one of the biggest bottlenecks.
Is there any way to make it faster?

Best regards


import torch
from torch.distributions import Categorical


class CategoricalMasked(Categorical):
    LOGIT_INF_VALUE = -1E10

    def __init__(self, mask, logits):
        # dtype of mask: torch.float32 (1.0 = possible action, 0.0 = impossible action)
        # example size of logits and mask: (4096, 5)
        self.mask = torch.gt(mask, 0.5)

        # replace the logits of impossible actions with a large negative value
        logits = torch.where(self.mask, logits,
                             torch.ones_like(logits) * torch.tensor(CategoricalMasked.LOGIT_INF_VALUE, device=logits.device))
        super(CategoricalMasked, self).__init__(logits=logits)

    def entropy(self):
        # self.logits holds the normalized log-probabilities
        p_log_p = self.logits * self.probs
        # impossible actions contribute nothing to the entropy
        p_log_p = torch.where(self.mask, p_log_p, torch.tensor(0., device=self.logits.device))
        return -p_log_p.sum(-1)
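A minimal usage sketch of just the masking step (not the full class), assuming a float 0/1 mask in which every row has at least one possible action. The helper name `mask_logits_where`, the random data and the roughly 30% masking rate are hypothetical. It checks that the large negative logit gives impossible actions zero probability, so they are never sampled.

```python
import torch
from torch.distributions import Categorical

LOGIT_INF_VALUE = -1e10  # same large negative constant as in the question


def mask_logits_where(logits, mask):
    # mask is float32: 1.0 = possible action, 0.0 = impossible action
    allowed = mask > 0.5
    return torch.where(allowed, logits, torch.full_like(logits, LOGIT_INF_VALUE))


torch.manual_seed(0)
logits = torch.randn(4096, 5)
mask = (torch.rand(4096, 5) > 0.3).float()
mask[:, 0] = 1.0  # assumption: action 0 is always possible, so no row is fully masked

dist = Categorical(logits=mask_logits_where(logits, mask))
actions = dist.sample()

# look up the mask value of each sampled action (1.0 means it was a possible action)
sampled_allowed = mask.gather(1, actions.unsqueeze(1)).squeeze(1)
print(bool((sampled_allowed == 1.0).all()))
# largest probability assigned to any impossible action
print(float(dist.probs[mask == 0].max()))
```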

Hi,

If you run on GPU, the following might be faster:

self.mask = self.mask.to(logits.dtype)
logits = self.mask * logits + (1 - self.mask) * CategoricalMasked.LOGIT_INF_VALUE
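A quick check, under assumptions (random logits, a float 0/1 mask with roughly 30% impossible actions, hypothetical helper names), that this pointwise expression produces exactly the same masked logits as the torch.where version, so swapping one for the other does not change the distribution:

```python
import torch

LOGIT_INF_VALUE = -1e10


def mask_logits_where(logits, mask):
    # mask: float 0/1 tensor with the same shape as logits
    return torch.where(mask > 0.5, logits, torch.full_like(logits, LOGIT_INF_VALUE))


def mask_logits_arith(logits, mask):
    # pointwise-only version: possible actions keep their logit,
    # impossible actions get LOGIT_INF_VALUE
    mask = mask.to(logits.dtype)
    return mask * logits + (1 - mask) * LOGIT_INF_VALUE


torch.manual_seed(0)
logits = torch.randn(4096, 5)
mask = (torch.rand(4096, 5) > 0.3).float()
print(torch.equal(mask_logits_where(logits, mask), mask_logits_arith(logits, mask)))
```

Note that the arithmetic form is only equivalent when the mask holds exactly 0.0 or 1.0: the `torch.gt(mask, 0.5)` threshold in the question tolerates other values, while `mask * logits + (1 - mask) * ...` would blend them.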

Hi @albanD, good news!
Why is torch.where slower than what you suggest?

That shouldn’t be faster than:
torch.lerp(LOGIT_INF_TENSOR, logits, self.mask)

and, in entropy():
p_log_p *= self.mask
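A sketch of both suggestions, assuming LOGIT_INF_TENSOR (not defined in the post) is a 0-dim tensor holding -1e10 with the same dtype as the logits, and that the mask is already a float tensor (torch.lerp expects the weight to have the same dtype as its inputs). It compares the lerp masking and the in-place multiplication in entropy() against the torch.where versions from the question; helper names and data are hypothetical. Because -1e10 is huge compared with typical logits, check on your PyTorch version that the kept logits come back unchanged, since that depends on how lerp evaluates start + weight * (end - start) in floating point.

```python
import torch
from torch.distributions import Categorical

LOGIT_INF_VALUE = -1e10


def mask_logits_lerp(logits, mask, inf_tensor):
    # torch.lerp(start, end, weight) = start + weight * (end - start):
    # weight 1.0 keeps the logit (`end`), weight 0.0 gives the constant (`start`)
    return torch.lerp(inf_tensor, logits, mask)


def masked_entropy_mul(dist, mask):
    # dist.logits are the normalized log-probabilities
    p_log_p = dist.logits * dist.probs
    p_log_p *= mask  # in place: impossible actions contribute 0
    return -p_log_p.sum(-1)


torch.manual_seed(0)
logits = torch.randn(4096, 5)
mask = (torch.rand(4096, 5) > 0.3).float()  # float mask, same dtype as logits
mask[:, 0] = 1.0  # assumption: action 0 is always possible
# hypothetical stand-in for LOGIT_INF_TENSOR: a 0-dim tensor that lerp broadcasts
inf_tensor = torch.tensor(LOGIT_INF_VALUE, dtype=logits.dtype)

dist = Categorical(logits=mask_logits_lerp(logits, mask, inf_tensor))
entropy = masked_entropy_mul(dist, mask)

# reference: the torch.where versions from the question
ref_logits = torch.where(mask > 0.5, logits, inf_tensor)
p_log_p = dist.logits * dist.probs
ref_entropy = -torch.where(mask > 0.5, p_log_p, torch.zeros_like(p_log_p)).sum(-1)
print(torch.allclose(mask_logits_lerp(logits, mask, inf_tensor), ref_logits))
print(torch.allclose(entropy, ref_entropy))
```

The multiplication trick relies on every entry of p_log_p being finite. With a finite constant like -1e10 the masked products are zero, but if the mask value were -inf, `-inf * 0` gives NaN and multiplying by the mask would not remove it, whereas torch.where would.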


Thanks, I’ll try it! Just one question: why is where slower than lerp? @googlebot

lerp should simply fuse expressions of the form x1*w + x2*(1-w), like the one @albanD suggested, into a single operation, avoiding the memory allocations for the intermediate results.
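Written out, torch.lerp(start, end, weight) computes start + weight * (end - start) elementwise, so with the arguments from the lerp suggestion it is algebraically the same expression albanD proposed. A short sketch of the algebra, with $s$ = LOGIT_INF_VALUE, $x$ = logits and $m$ = mask (these symbols are introduced here for illustration):

```latex
\operatorname{lerp}(s, x, m) = s + m\,(x - s) = m\,x + (1 - m)\,s,
\qquad
\operatorname{lerp}(s, x_i, m_i) =
\begin{cases}
x_i & \text{if } m_i = 1,\\
s   & \text{if } m_i = 0.
\end{cases}
```

Evaluated unfused, the right-hand side runs four pointwise kernels ($m\,x$, $1 - m$, $(1 - m)\,s$ and the final addition) and allocates a temporary for each of the first three, while lerp computes the result in a single pass.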

I don’t know whether this or torch.where is faster. The latter could be slower if it disrupts memory access patterns on CUDA (but I’d think this is avoidable).


Oh, I don’t know which one is faster. But I do know that pointwise addition/multiplication kernels are heavily optimized and still actively being worked on, while the where kernel has not been scrutinized as much. I’m curious to know which one is actually faster, compared with Alex’s proposal as well.
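Rather than guessing, the variants can be timed. A minimal benchmark sketch using `torch.utils.benchmark`, assuming the (4096, 5) size from the question, random data and about 30% impossible actions; the helper names are hypothetical. It only times the masking of the logits: the question’s version additionally builds `torch.ones_like(logits) * torch.tensor(..., device=logits.device)` on every call, which adds an allocation, a multiplication and, on GPU, a small host-to-device copy, so include those if you want to compare against the original code. Results will depend on your hardware and PyTorch version.

```python
import torch
from torch.utils import benchmark

LOGIT_INF_VALUE = -1e10


def mask_where(logits, bool_mask, inf_tensor):
    return torch.where(bool_mask, logits, inf_tensor)


def mask_arith(logits, mask):
    return mask * logits + (1 - mask) * LOGIT_INF_VALUE


def mask_lerp(logits, mask, inf_tensor):
    return torch.lerp(inf_tensor, logits, mask)


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
logits = torch.randn(4096, 5, device=device)  # size from the question
mask = (torch.rand(4096, 5, device=device) > 0.3).float()  # assumed ~30% impossible actions
bool_mask = mask > 0.5
inf_tensor = torch.tensor(LOGIT_INF_VALUE, device=device)

# sanity check: all three variants must produce the same masked logits
ref = mask_where(logits, bool_mask, inf_tensor)
assert torch.allclose(ref, mask_arith(logits, mask))
assert torch.allclose(ref, mask_lerp(logits, mask, inf_tensor))

env = {"mask_where": mask_where, "mask_arith": mask_arith, "mask_lerp": mask_lerp,
       "logits": logits, "mask": mask, "bool_mask": bool_mask, "inf_tensor": inf_tensor}
stmts = {
    "where": "mask_where(logits, bool_mask, inf_tensor)",
    "arith": "mask_arith(logits, mask)",
    "lerp": "mask_lerp(logits, mask, inf_tensor)",
}

results = {}
for name, stmt in stmts.items():
    # benchmark.Timer synchronizes CUDA when needed, so GPU timings are meaningful
    timer = benchmark.Timer(stmt=stmt, globals=env, label="logit masking", sub_label=name)
    results[name] = timer.timeit(1000)
    print(f"{name:>5}: {results[name].median * 1e6:.1f} us per call")
```

At this size each call touches only about 20k elements, so the number of kernel launches and temporary allocations may matter as much as the speed of any single kernel; it is worth re-running with your real batch size.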
