How to add padding mask to nn.TransformerEncoder module?

I think that when using src_mask we need to provide a matrix of shape (S, S), where S is the source sequence length. For example, with a single self-attention layer:

import torch, torch.nn as nn
q = torch.randn(3, 1, 10) # source sequence length 3, batch size 1, embedding size 10
attn = nn.MultiheadAttention(10, 1) # embedding size 10, one head
attn(q, q, q) # self attention
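For reference, with the default batch_first=False the module takes (S, N, E) inputs and returns a pair (attn_output, attn_output_weights), where the weights are averaged over heads. A quick shape check (the seed is arbitrary and only there for reproducibility):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
q = torch.randn(3, 1, 10)            # (S=3, N=1, E=10)
attn = nn.MultiheadAttention(10, 1)  # embed_dim=10, num_heads=1

out, weights = attn(q, q, q)         # self-attention: query = key = value
print(out.shape)      # torch.Size([3, 1, 10]) -> (S, N, E)
print(weights.shape)  # torch.Size([1, 3, 3])  -> (N, S, S)
```

Each row of the weights is a softmax over the source positions, so it sums to 1.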

For attn_mask we need a matrix of shape (S, S). The function below builds one that stops each position from attending to later positions:

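def src_mask(sz):
  # Lower-triangular boolean matrix: True where position i may attend to position j (j <= i)
  mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
  # Allowed positions become 0.0, blocked (future) positions become -inf
  mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
  return mask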
src_mask(3)

gives

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])
attn(q, q, q, attn_mask=src_mask(3))[1] # attention output weights

gives

tensor([[[1.0000, 0.0000, 0.0000],
          [0.4679, 0.5321, 0.0000],
          [0.3934, 0.3740, 0.2326]]], grad_fn=<DivBackward0>)
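The weights printed above can be reproduced by hand, which makes it concrete that the mask is only an additive term before the softmax. This is a sketch under the setup of the example (one head and default arguments, so dropout is 0 and the packed in_proj_weight/in_proj_bias are used). The seed and variable names are illustrative, and only the final weights are compared, because PyTorch's internals differ between versions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
S, E = 3, 10
q = torch.randn(S, 1, E)             # (S, N, E)
attn = nn.MultiheadAttention(E, 1)   # one head, so head_dim == E
# Same values as src_mask(3): 0 on and below the diagonal, -inf above it
mask = torch.triu(torch.full((S, S), float('-inf')), diagonal=1)

# The packed input projection stacks the query, key and value weights row-wise
w_q, w_k, _ = attn.in_proj_weight.chunk(3)
b_q, b_k, _ = attn.in_proj_bias.chunk(3)

x = q.transpose(0, 1)                          # (N, S, E)
Q = F.linear(x, w_q, b_q)
K = F.linear(x, w_k, b_k)
scores = Q @ K.transpose(1, 2) / E ** 0.5      # scaled dot products, (N, S, S)
manual = torch.softmax(scores + mask, dim=-1)  # add the mask, then softmax

_, weights = attn(q, q, q, attn_mask=mask)
print(torch.allclose(manual, weights, atol=1e-6))  # True
```

The entries above the diagonal come out as exactly 0, since exp(-inf) is 0.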

If we look at F.multi_head_attention_forward, all attn_mask does is get added to the attention weights before the softmax:

if attn_mask is not None:
    attn_mask = attn_mask.unsqueeze(0)
    attn_output_weights += attn_mask

Since we added float('-inf') to some of the weights, the softmax returns zero for those entries, for example:

a = nn.Softmax(dim=-1)
b = torch.tensor([3., 4., float('-inf')])
a(b)

tensor([0.2689, 0.7311, 0.0000])

This means that some words are ignored when finding the representation of a word. For example, when finding the attention weights for the first word in our source sentence, we do not want to consider the following words; for the second word, we want to consider only the first and second words, not the third.
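One way to check this: with the causal mask, changing the third word should leave the outputs for the first two positions untouched, because their rows put zero weight on it. A sketch (the perturbation of +1.0 is arbitrary, and the mask is the same lower-triangular pattern as src_mask(3), built here with torch.triu):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
S = 3
attn = nn.MultiheadAttention(10, 1)
causal = torch.triu(torch.full((S, S), float('-inf')), diagonal=1)

q = torch.randn(S, 1, 10)
q_changed = q.clone()
q_changed[2] += 1.0  # perturb only the third word

out, _ = attn(q, q, q, attn_mask=causal)
out_changed, _ = attn(q_changed, q_changed, q_changed, attn_mask=causal)

# Positions 0 and 1 never attend to position 2, so their outputs are identical
print(torch.allclose(out[:2], out_changed[:2]))  # True
# Position 2 attends to itself, so its output changes
print(torch.allclose(out[2], out_changed[2]))    # False
```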

As for src_key_padding_mask, it has to be of shape (N, S), where N is the batch size and S is the source sequence length; with a boolean mask, positions set to True are ignored.
I think its purpose is to stop padded words from being considered when finding the representation of other words.
For example, if we do not want to consider the third word in our source sequence when computing attention weights (batch size of 1):

src_key_padding_mask = torch.tensor([[0, 0, 1]]).bool()
attn(q, q, q, attn_mask=src_mask(3), key_padding_mask=src_key_padding_mask)[1]

gives

tensor([[[1.0000, 0.0000, 0.0000],
         [0.4679, 0.5321, 0.0000],
         [0.5127, 0.4873, 0.0000]]], grad_fn=<DivBackward0>)

The third column is now all zeros, because the third word no longer contributes to the representation of any word.
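Putting this together for the question in the title: nn.TransformerEncoder.forward accepts mask (the (S, S) attention mask) and src_key_padding_mask (the (N, S) padding mask) and passes both to the self-attention of every layer. The sketch below builds the padding mask from per-sequence lengths. The lengths, sizes, and the helper name padding_mask_from_lengths are illustrative assumptions, not part of the answer above. Both masks are boolean here (True = blocked), which nn.MultiheadAttention accepts as an alternative to the float -inf form:

```python
import torch
import torch.nn as nn

def padding_mask_from_lengths(lengths, max_len):
    # Hypothetical helper: True marks padded positions (index >= true length)
    return torch.arange(max_len)[None, :] >= lengths[:, None]

torch.manual_seed(0)
S, N, E = 4, 2, 10
src = torch.randn(S, N, E)                        # (S, N, E), batch_first=False
lengths = torch.tensor([4, 2])                    # illustrative: 2nd sequence has 2 pad tokens
pad_mask = padding_mask_from_lengths(lengths, S)  # (N, S); row 1 is True at positions 2 and 3
causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)  # (S, S); True = future position

layer = nn.TransformerEncoderLayer(d_model=E, nhead=1, dropout=0.0)
encoder = nn.TransformerEncoder(layer, num_layers=2)

# mask -> attn_mask of each layer, src_key_padding_mask -> key_padding_mask
out = encoder(src, mask=causal, src_key_padding_mask=pad_mask)
print(out.shape)  # torch.Size([4, 2, 10])
```

In this sketch the outputs at padded positions are still computed; the padding mask only removes them as keys. So downstream code, such as pooling over the sequence, usually has to mask them as well.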
