How to achieve fast masking?

How can I speed up the following masking, ideally without the Python loop:

import torch
q = torch.Tensor([2, 3, 6, 2])
x = torch.randn(4, 10, 3, 3)
for i in range(4):
    x[i, int(q[i]):] = 0
print(x)

Hi Bin!

You can get rid of the loop by multiplying x in place by a boolean
mask tensor that broadcasts over the trailing dimensions:

>>> torch.__version__
'1.10.2'
>>> _ = torch.manual_seed (2022)
>>> q = torch.Tensor([2, 3, 6, 2])
>>> x = torch.randn(4, 10, 3, 3)
>>> xB = x.clone()
>>> for i in range(4):
...     x[i, int(q[i]):] = 0
...
>>> xB *= (torch.arange (10).unsqueeze (0).expand (4, 10) < q.unsqueeze (1)).unsqueeze (-1).unsqueeze (-1)
>>> torch.equal (x, xB)
True
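
As a variation on the same idea, here is a minimal sketch that builds the mask by broadcasting alone (no `expand()`) and applies it with `masked_fill()` instead of multiplication. The names `keep`, `masked`, and `expected` are just illustrative, and I'm assuming `q` holds integer cut-off indices along dimension 1, as in the example above.

```python
import torch

torch.manual_seed(2022)
q = torch.tensor([2, 3, 6, 2])          # cut-off index for each batch element
x = torch.randn(4, 10, 3, 3)

# Reference result from the original loop, kept only for comparison.
expected = x.clone()
for i in range(4):
    expected[i, int(q[i]):] = 0

# keep[i, j] is True where j < q[i]; broadcasting (10,) against (4, 1) gives (4, 10).
keep = torch.arange(x.size(1)) < q.unsqueeze(1)

# Fill the positions that are not kept with zero. The (4, 10, 1, 1) mask
# broadcasts over the last two dimensions of x.
masked = x.masked_fill(~keep[:, :, None, None], 0.0)

print(torch.equal(masked, expected))
```

One difference that may matter: multiplying by the mask leaves any `nan` or `inf` in the masked region as `nan`, while `masked_fill()` overwrites those entries with zero. If you want the update in place, `x.masked_fill_(...)` should do the same thing without allocating a new tensor.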

Best.

K. Frank
