Hi, I was trying to use the PyTorch implementation of multi-head attention with a key_padding_mask for the padded tokens. However, the mask seems to be doing something weird. Here is an example:
I have three tokens, each with an embedding dimension of 4. Assuming the last token is padding, I added the mask [False, False, True]. I would expect the attention weight for the masked token to be zero:
import torch
from torch import nn

# embed_dim = 4, 2 heads, no dropout
self_attn = nn.MultiheadAttention(4, 2, dropout=0, batch_first=True)
x = torch.rand((3, 4))  # sequence length x embed_dim
# mask intended to mark the last (padding) token
res, attn = self_attn(x, x, x, key_padding_mask=torch.Tensor([False, False, True]), need_weights=True)
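In case it is relevant, I printed the mask I am passing in, and it comes out as a float tensor rather than a bool one (not sure whether that matters):

mask = torch.Tensor([False, False, True])
print(mask)        # tensor([0., 0., 1.])
print(mask.dtype)  # torch.float32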
The attention matrix then comes out like this:
attn = tensor([[0.2125, 0.2203, 0.5673],
               [0.2041, 0.2077, 0.5881],
               [0.2086, 0.2112, 0.5803]], grad_fn=<SqueezeBackward1>)
The attention weight for the masked token (the last column) is not zero; in fact, it is the largest value in each row.
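To be clear about what I expected, I thought the weights would come out roughly like this (made-up numbers), with the last column zeroed out and each row still summing to 1:

tensor([[0.49, 0.51, 0.0000],
        [0.47, 0.53, 0.0000],
        [0.48, 0.52, 0.0000]])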
What is going on?