Attn_mask in nn.MultiheadAttention

Hi All,

I'm trying to understand why the mask does not seem to work. There is a similar question here that seems to resolve the issue, but the same code below doesn't return 0s in the upper triangle of the attention matrix. The Torch version is 1.13.1.

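```
import torch
from torch import nn

inputs = torch.ones((4, 3))  # unbatched input: sequence length 4, embedding size 3
maten = nn.MultiheadAttention(embed_dim=3, num_heads=1)
attn_mask = torch.tril(torch.ones((4, 4)))
attn_mask = attn_mask > 0  # True on and below the diagonal
# ~attn_mask is True above the diagonal; True marks positions that may NOT be attended to
query2, _ = maten(inputs, inputs, inputs, attn_mask=~attn_mask)  # `_` discards the attention weights
query2

tensor([[-0.0128,  0.0919,  0.1716],
        [-0.0128,  0.0919,  0.1716],
        [-0.0128,  0.0919,  0.1716],
        [-0.0128,  0.0919,  0.1716]], grad_fn=<SqueezeBackward1>)
```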

I think my expectation was wrong: I expected 0s in the upper triangle, but query2 is the attention output, not the attention matrix. Since attn_mask is added to the query-key product, and the softmax of that sum is then multiplied by the values, the output should not have 0s in the upper triangle, right?

I think it's basically the following, where the term before '@' has 0s above the diagonal. I would appreciate it if anyone could confirm.

mask = torch.tril(torch.ones((4, 4)))  # float version of attn_mask, before the `> 0` conversion
mask[mask == 0] = -float("Inf")
torch.softmax(Q @ K.transpose(1, 0) + mask, dim=-1) @ V
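
A quick numerical check of this formula, as a minimal sketch only. Q, K and V here are random stand-ins (not the projections that nn.MultiheadAttention computes internally), `seq_len` and `d` are assumed values, and the additive mask uses 0 on and below the diagonal. Scaling is left out to mirror the formula above.

```python
import torch

torch.manual_seed(0)
seq_len, d = 4, 3  # assumed sequence length and embedding size
Q = torch.randn(seq_len, d)
K = torch.randn(seq_len, d)
V = torch.randn(seq_len, d)

# Additive mask: 0 on and below the diagonal, -inf strictly above it.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# The term before '@': softmax turns every -inf score into a weight of exactly 0.
weights = torch.softmax(Q @ K.transpose(1, 0) + mask, dim=-1)
out = weights @ V

print(weights)  # zeros above the diagonal, each row sums to 1
print(out)      # weighted averages of the rows of V, so no zero pattern is expected
```

The first row of `weights` is [1, 0, 0, 0], so the first row of `out` is just the first row of V. The zeros live in the attention weights, not in the output.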

The attention mask should be 0 on and below the diagonal and -inf strictly above it, i.e. an upper triangular matrix whose non-zero entries are all -inf.

import torch

arr = torch.full((size, size), float("-inf"))  # size x size matrix filled with -inf
mask = torch.triu(arr, diagonal=1) # Here diagonal=1 is used to make diagonal entries 0

When you print the mask, you get:

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])
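
For a concrete case, here is a small sketch with `size = 4` (a hypothetical value chosen to match the 4x4 mask in the question). It also checks that this float mask says the same thing as the boolean lower-triangular matrix from the question once the disallowed positions are filled with -inf. The names `allowed` and `mask_from_bool` are made up for illustration.

```python
import torch

size = 4  # hypothetical sequence length, matching the 4x4 mask in the question

# Float mask: -inf strictly above the diagonal, 0 elsewhere.
mask = torch.triu(torch.full((size, size), float("-inf")), diagonal=1)

# Boolean matrix from the question: True where attention is allowed.
allowed = torch.tril(torch.ones(size, size)) > 0
# Start from zeros and put -inf wherever attention is NOT allowed.
mask_from_bool = torch.zeros(size, size).masked_fill(~allowed, float("-inf"))

print(mask)
print(torch.equal(mask, mask_from_bool))  # True
```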

The size variable here is the sequence length, i.e. the number of rows of your Q and K matrices.
Your application of the mask is correct: it is added to the matrix product of Q and K. However, you must also scale Q@K.T by 1/sqrt(d_k), where d_k is the per-head embedding dimension, before adding the mask.
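
To see the zeros directly, look at the second return value of nn.MultiheadAttention (the attention weights) rather than the output. The sketch below is only an illustration under a few assumptions: it uses random inputs instead of torch.ones (with identical input rows, every output row is the same weighted average no matter what the mask is), and it recomputes the weights by hand from `in_proj_weight` and `in_proj_bias` for the single-head, unbatched case. The variable names are mine.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
seq_len, embed_dim = 4, 3
inputs = torch.randn(seq_len, embed_dim)  # random rows instead of all ones
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=1)

# Boolean mask: True marks positions that may NOT be attended to.
bool_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
# Equivalent float mask: 0 where allowed, -inf where blocked.
float_mask = torch.zeros(seq_len, seq_len).masked_fill(bool_mask, float("-inf"))

out_b, w_b = mha(inputs, inputs, inputs, attn_mask=bool_mask)
out_f, w_f = mha(inputs, inputs, inputs, attn_mask=float_mask)
print(w_b)  # attention weights: zeros above the diagonal

# Manual recomputation of the weights: scale Q@K.T, then add the mask.
with torch.no_grad():
    w_q, w_k, _ = mha.in_proj_weight.chunk(3)
    b_q, b_k, _ = mha.in_proj_bias.chunk(3)
    q = inputs @ w_q.T + b_q
    k = inputs @ w_k.T + b_k
    scores = q @ k.T / math.sqrt(embed_dim)  # one head, so d_k == embed_dim
    manual = torch.softmax(scores + float_mask, dim=-1)

print(torch.allclose(manual, w_f, atol=1e-6))
```

Both mask styles should give the same weights and outputs, and the hand-computed weights should agree with the module up to floating-point tolerance.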

Thank you @Theo, that makes total sense. Since we are adding the two matrices we want 0s on the lower triangle.