Replacing the for loop with PyTorch's tensor operation(s)

Hi all, I have a tensor with integers in the range [0, N-1] and I need to create N masks, one for each integer value, i.e., in the 0’s mask all values of 0 should be True and the rest False, and so on. So far I could only come up with a for loop over the range, but that drastically increases the runtime. I’d be grateful if someone could suggest equivalent PyTorch tensor operation(s) instead.
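For reference, the loop-based approach described above might look roughly like the sketch below. This is a hypothetical reconstruction, since the original loop wasn't posted; `masks_with_loop` is an illustrative name, and the example tensor mirrors the one used in the reply.

```python
import torch

def masks_with_loop(t: torch.Tensor, n: int) -> torch.Tensor:
    # Hypothetical loop version: build one boolean mask per value k in [0, n-1].
    masks = []
    for k in range(n):
        masks.append(t == k)  # True where t holds the value k
    # Stack along a new leading axis -> shape (n, *t.shape)
    return torch.stack(masks)

torch.manual_seed(2021)
N = 10
t = torch.randint(N, (2, 3, 5))
masks = masks_with_loop(t, N)
print(masks.shape)  # torch.Size([10, 2, 3, 5])
```

Each pass through the loop launches a separate comparison, so the cost grows with N; that per-value overhead is what a single broadcast comparison avoids.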

Hi Melike!

Try:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> N = 10
>>> t = torch.randint (N, (2, 3, 5))
>>> i = torch.ones ([N] + list (t.shape)).cumsum (0) - 1
>>> masks = t == i
>>> t
tensor([[[4, 9, 3, 2, 5],
         [6, 2, 0, 5, 0],
         [1, 9, 7, 7, 4]],

        [[8, 0, 0, 2, 4],
         [1, 7, 5, 7, 0],
         [1, 6, 9, 8, 6]]])
>>> masks[0]
tensor([[[False, False, False, False, False],
         [False, False,  True, False,  True],
         [False, False, False, False, False]],

        [[False,  True,  True, False, False],
         [False, False, False, False,  True],
         [False, False, False, False, False]]])
>>> masks[4]
tensor([[[ True, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False, False],
         [False, False, False, False, False]]])
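The trick above is that `i[k]` is a tensor with the same shape as `t`, filled with the value `k` (the cumulative sum of ones along dim 0, minus one), so a single `==` produces all N masks at once. Below is a minimal sketch of two alternatives that are not part of the original answer: reshaping a 1-D `torch.arange` so that broadcasting does the expansion, and `torch.nn.functional.one_hot` with the class axis moved to the front. It assumes PyTorch 1.7 or later (for `movedim`) and an int64 input tensor, which `one_hot` requires.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(2021)
N = 10
t = torch.randint(N, (2, 3, 5))  # int64 values in [0, N-1]

# Approach from the answer: i[k] is a float tensor shaped like t, filled with k.
i = torch.ones([N] + list(t.shape)).cumsum(0) - 1
masks_cumsum = t == i

# Alternative 1: arange of shape (N, 1, 1, 1) broadcasts against t's (2, 3, 5),
# so the full (N, 2, 3, 5) index tensor is never materialized.
values = torch.arange(N).view(N, *([1] * t.dim()))
masks_arange = t == values

# Alternative 2: one_hot appends the class axis last; move it to the front.
masks_one_hot = F.one_hot(t, N).movedim(-1, 0).bool()

print(masks_arange.shape)  # torch.Size([10, 2, 3, 5])
print(torch.equal(masks_cumsum, masks_arange))  # True
print(torch.equal(masks_arange, masks_one_hot))  # True
```

All three give the same (N, *t.shape) boolean tensor. The `arange` version also keeps the comparison integer-to-integer; the float comparison in the original is still exact here, because float32 represents every integer up to 2**24 exactly.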

Best.

K. Frank
