Create a mask tensor using index

I’m trying to create *_key_padding_mask for torch.nn.Transformer, so I want to create a mask from the lengths of the sentence data in the batch.
Assuming the number of time steps t=8.

Edit: changed the value of t from 7 to 8

batch_size = 5
idx = torch.tensor([3,4,2,1,6])


mask = torch.tensor([[0,0,0,1,1,1,1,1],
                     [0,0,0,0,1,1,1,1],
                     [0,0,1,1,1,1,1,1],
                     [0,1,1,1,1,1,1,1],
                     [0,0,0,0,0,0,1,1]])

How do I get mask using idx with vectorized code?

import torch

batch_size = 5

idx = torch.tensor([3,4,1,1,6])

# I changed the mask a little

mask = torch.tensor([[0,0,0,1,1,1,1,1],
                     [0,0,0,0,1,1,1,1],
                     [0,0,1,1,1,1,1,1],
                     [0,1,1,1,1,1,1,1],
                     [0,0,0,0,0,0,1,1]])

print(torch.gather(mask, dim=1, index=idx.unsqueeze(-1).expand(mask.size(0), -1)))

Not getting the expected output. The mask’s value in my question is the expected output.

You already seem to have initialized the mask with the expected output.

This doesn’t seem easy in current PyTorch, because the existing range-style operations only accept scalars as parameters rather than tensors. Something like torch.range(tensor_start, tensor_end) would solve your problem. You can also look at https://github.com/pytorch/nestedtensor (which is still experimental).
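That said, a plain broadcasted comparison against `torch.arange` also produces this mask fully vectorized, without needing tensor-valued range arguments; a minimal sketch, using the idx from the question:

```python
import torch

t = 8
idx = torch.tensor([3, 4, 2, 1, 6])

# position j in row i is masked (1) exactly when j >= idx[i];
# arange gives shape (1, t), idx gives shape (batch, 1), broadcasting does the rest
mask = (torch.arange(t).unsqueeze(0) >= idx.unsqueeze(1)).long()
print(mask)
```

This prints the same mask as in the question.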

This piece of code may be a good start:

torch.repeat_interleave(torch.tensor([0, 1, 0, 1]), torch.tensor([3, 8-3, 4, 8-4])).view(2, 8)
# gives 
# tensor([[0, 0, 0, 1, 1, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1, 1, 1]])

This generalizes to a one-liner:

torch.repeat_interleave(torch.tensor([0, 1]*batch_size), torch.stack([idx, 8-idx], dim=1).view(-1)).view(batch_size, 8)
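Put together as a runnable snippet (using the idx from the original question, so the result matches the expected mask):

```python
import torch

batch_size = 5
t = 8
idx = torch.tensor([3, 4, 2, 1, 6])

# each row is idx[i] zeros followed by (t - idx[i]) ones
values = torch.tensor([0, 1] * batch_size)             # alternating run values per row
lengths = torch.stack([idx, t - idx], dim=1).view(-1)  # interleaved run lengths
mask = torch.repeat_interleave(values, lengths).view(batch_size, t)
print(mask)
```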

@pfloat’s solution with torch.repeat_interleave is quite interesting. However, I found something simple.

Upper triangular matrix

t = 8
tri = torch.triu(torch.ones(t, t), diagonal=1)
tri
tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

Close, but if the idx tensor says 3 then index 3 must also be masked.

tri.fill_diagonal_(1)
tri
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.]])

Now rows of this upper triangular matrix can be indexed to create the mask:

tri[3]
tensor([0., 0., 0., 1., 1., 1., 1., 1.])

idx = torch.tensor([4,1,2,5,3,6])
tri[idx]
tensor([[0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1.]])

Since we only need 0 or 1, we can cast the matrix to dtype bool or byte, and this way the method also works for large values of t.
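A compact version of this approach, stored as bool and built with diagonal=0 so the separate fill_diagonal_ step isn’t needed (that merge of the two steps is my own simplification, not from the posts above):

```python
import torch

t = 8
idx = torch.tensor([3, 4, 2, 1, 6])

# diagonal=0 already includes the main diagonal, so tri[i] has
# exactly i leading False values followed by True (padded positions)
tri = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=0)
mask = tri[idx]
print(mask)
```

Building `tri` once and indexing it with each batch’s idx keeps the per-batch cost to a single fancy-indexing lookup.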


@sidwa0 Nice solution!