How to add padding mask to nn.TransformerEncoder module?

I want to use the vanilla transformer (only the encoder side), but I don't know how and where to add the padding mask.


I am also facing the same trouble. Did you find any solution? @ptrblck, could you please give some time if you are available? Thanks in advance.

Specifically, I am having trouble understanding how to provide the padded-sequence mask to TransformerEncoderLayer. It has two mask parameters, src_mask and src_key_padding_mask: what should their content be (boolean, or -inf/0) and what shape should they have? Which parameter is responsible for the sequence padding mask?


I think when using src_mask, we need to provide a matrix of shape (S, S), where S is our source sequence length. For example,

import torch, torch.nn as nn
q = torch.randn(3, 1, 10) # source sequence length 3, batch size 1, embedding size 10
attn = nn.MultiheadAttention(10, 1) # embedding size 10, one head
attn(q, q, q) # self attention

for attn_mask, we need a matrix of shape (S, S),

def src_mask(sz):
  mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
  mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
  return mask
src_mask(3)

gives

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])
attn(q, q, q, attn_mask=src_mask(3))[1] # attention output weights

gives

tensor([[[1.0000, 0.0000, 0.0000],
         [0.4679, 0.5321, 0.0000],
         [0.3934, 0.3740, 0.2326]]], grad_fn=<DivBackward0>)

if we look at F.multi_head_attention_forward, then what attn_mask is doing is,

if attn_mask is not None:
    attn_mask = attn_mask.unsqueeze(0)
    attn_output_weights += attn_mask

since we added float('-inf') to some of the weights, the softmax returns zero for those positions, for example,

a = nn.Softmax(dim=-1)
b = torch.tensor([3., 4., float('-inf')])
a(b)

tensor([0.2689, 0.7311, 0.0000])

which means that we do not consider some words when finding the representation of a word. For example, when finding the attention weights for the first word in our source sentence, we do not want to consider the following words; for the second word, we want to consider only the first and second words, and not the third.

as for src_key_padding_mask, it has to be of shape (N, S), where N is the batch size and S is the source sequence length.
I think it is there so that we do not consider padded positions when finding the representation of the other words.
for example, if we do not want to consider the third word in our source sequence when finding the attention weights (with a batch size of 1),

src_key_padding_mask = torch.tensor([[0, 0, 1]]).bool()
attn(q, q, q, attn_mask=src_mask(3), key_padding_mask=src_key_padding_mask)[1]

gives

tensor([[[1.0000, 0.0000, 0.0000],
         [0.4679, 0.5321, 0.0000],
         [0.5127, 0.4873, 0.0000]]], grad_fn=<DivBackward0>)

the third column is always zero, because we did not let the third word have any impact on the representation of the other words.
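for comparison, key_padding_mask is handled in a similar way inside F.multi_head_attention_forward: every key position marked True gets its whole column filled with -inf before the softmax. roughly paraphrased from the PyTorch source (the exact code differs between versions),

if key_padding_mask is not None:
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    attn_output_weights = attn_output_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2),  # broadcast the (N, S) mask over heads and query positions
        float('-inf'),
    )
    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)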

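to tie this back to the original question: nn.TransformerEncoder simply forwards these two masks to every nn.TransformerEncoderLayer, so the padding mask goes in as src_key_padding_mask of shape (N, S). a minimal sketch, with d_model, nhead and num_layers as placeholder values,

import torch, torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=10, nhead=1)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)

src = torch.randn(3, 1, 10)                          # (S, N, E): sequence length 3, batch size 1, embedding size 10
padding_mask = torch.tensor([[False, False, True]])  # (N, S): True marks padded positions
out = encoder(src, src_key_padding_mask=padding_mask)
out.shape  # torch.Size([3, 1, 10])

a causal mask like src_mask(3) above can additionally be passed through the mask argument, but for a plain encoder the padding mask is usually all you need.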

Thanks a lot @vainaijr for the concise explanation :smiley:


Could someone say more about why exactly the batch size N is included in the src_key_padding_mask dimensions?

I'm getting as mask:

tensor([[False],
        [False],
        [False],
        [False],
        [True],
        [True]], device='cuda:0')

of shape (batch_size, target_length), where False is a non-padding token and True marks padding tokens. Is that correct?

I think the batch size here works like the number of sentences, so if we specify

q = torch.randn(3, 4, 10) # source sequence length 3, batch size 4, embedding size 10

then it would mean something like this,

first sentence -> I eat fruit.

second sentence -> He goes out.

third sentence -> You did well.

fourth sentence -> They are here.

attn = nn.MultiheadAttention(10, 2) # embedding size 10, two heads
# which means each word is represented by 10 numbers, and two heads means we divide these 10 numbers into two groups of 5
self_attn = attn(q, q, q)

printing the values (from inside F.multi_head_attention_forward) gives,

q.shape:  torch.Size([8, 3, 5])  # batch_size * num_heads, target_length, embedding size per head
k.shape:  torch.Size([8, 3, 5])
v.shape:  torch.Size([8, 3, 5])
attn_output.shape:  torch.Size([3, 4, 10])  # target_length, batch_size, total embedding size (the learnt representation of the 12 words)
attn_output_weights.shape:  torch.Size([4, 3, 3])  # batch_size, target_length, source_length (the impact of each word in a sentence on each word)
attn_output tensor([[[-0.1518, -0.1691,  0.0284, -0.0418,  0.0363, -0.1349, -0.0698,
           0.1362,  0.0447,  0.0937],
         [ 0.1111, -0.0311,  0.1227,  0.0483, -0.0948, -0.0280, -0.0651,
          -0.0260,  0.1558, -0.0052],
         [ 0.0418, -0.0926,  0.0445,  0.4777, -0.0827,  0.1084,  0.1056,
          -0.0652,  0.0163, -0.1016],
         [ 0.0301,  0.1693,  0.2875,  0.1530, -0.5764, -0.0436, -0.5433,
           0.2245,  0.7345, -0.2269]],

        [[-0.1020, -0.1816,  0.1032,  0.0121, -0.0234, -0.1059, -0.0716,
           0.1583,  0.1086,  0.0887],
         [ 0.1091, -0.0786,  0.0826,  0.1906, -0.1552,  0.0493, -0.0796,
           0.0217,  0.2836, -0.0477],
         [ 0.0909, -0.1036, -0.0246,  0.4748, -0.0393,  0.1554,  0.2303,
          -0.0569, -0.1089, -0.0686],
         [ 0.0811,  0.2390,  0.3590,  0.3205, -0.6570,  0.0034, -0.5537,
           0.1301,  0.7532, -0.2921]],

        [[-0.1200, -0.1573,  0.0923,  0.0381, -0.0095, -0.0785, -0.0614,
           0.1465,  0.0847,  0.0854],
         [ 0.0930, -0.0545,  0.0920,  0.1169, -0.1343, -0.0320, -0.0359,
           0.0426,  0.1419, -0.0135],
         [-0.0124, -0.1950, -0.0552,  0.4579, -0.1132,  0.1364,  0.1459,
           0.0463, -0.0546, -0.0625],
         [-0.0338,  0.1369,  0.2610, -0.0500, -0.4532, -0.1353, -0.5206,
           0.2944,  0.6668, -0.1662]]], grad_fn=<AddBackward0>) 
attn_output_weights tensor([[[0.3817, 0.2956, 0.3227],
         [0.3113, 0.3434, 0.3453],
         [0.3532, 0.3375, 0.3093]],

        [[0.2794, 0.3996, 0.3210],
         [0.3830, 0.2820, 0.3350],
         [0.3729, 0.2966, 0.3305]],

        [[0.3569, 0.4500, 0.1931],
         [0.2561, 0.4059, 0.3380],
         [0.4600, 0.2529, 0.2871]],

        [[0.3454, 0.3205, 0.3341],
         [0.4038, 0.3868, 0.2094],
         [0.2882, 0.2299, 0.4819]]], grad_fn=<DivBackward0>)
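as an aside on where the (8, 3, 5) shape comes from: with embed_dim 10 and 2 heads, each head gets head_dim = 10 // 2 = 5, and the (S, N, E) input is reshaped to (N * num_heads, S, head_dim) before the per-head attention. a rough sketch of just that reshape, with the input/output projections omitted,

import torch

S, N, E, num_heads = 3, 4, 10, 2
head_dim = E // num_heads                  # 5
q = torch.randn(S, N, E)
# same reshape that F.multi_head_attention_forward performs after the projections
q_heads = q.contiguous().view(S, N * num_heads, head_dim).transpose(0, 1)
q_heads.shape  # torch.Size([8, 3, 5])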

on applying src_mask the same way as above, we get,

tensor([[[1.0000, 0.0000, 0.0000],
         [0.3729, 0.6271, 0.0000],
         [0.4581, 0.2948, 0.2471]],

        [[1.0000, 0.0000, 0.0000],
         [0.5356, 0.4644, 0.0000],
         [0.3669, 0.3504, 0.2827]],

        [[1.0000, 0.0000, 0.0000],
         [0.5249, 0.4751, 0.0000],
         [0.3915, 0.3403, 0.2682]],

        [[1.0000, 0.0000, 0.0000],
         [0.3817, 0.6183, 0.0000],
         [0.3671, 0.3208, 0.3122]]], grad_fn=<DivBackward0>)

the 2nd and 3rd values of the first row, and the 3rd value of the 2nd row, in each 3x3 tensor (that is, for each sentence) are zeroed out.
when we do,

src_key_padding_mask = torch.tensor([[1, 0, 0], [0, 0, 0], [0, 1, 0], [0, 0, 1]]).bool()
attn(q, q, q, attn_mask=src_mask(3), key_padding_mask=src_key_padding_mask)[1]

it gives,

tensor([[[   nan,    nan,    nan],
         [0.0000, 1.0000, 0.0000],
         [0.0000, 0.5640, 0.4360]],

        [[1.0000, 0.0000, 0.0000],
         [0.5356, 0.4644, 0.0000],
         [0.3669, 0.3504, 0.2827]],

        [[1.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000],
         [0.5952, 0.0000, 0.4048]],

        [[1.0000, 0.0000, 0.0000],
         [0.3817, 0.6183, 0.0000],
         [0.5341, 0.4659, 0.0000]]], grad_fn=<DivBackward0>)

so here, for the first sentence, the first row becomes all nan, because we set all three values to '-inf': the 2nd and 3rd values via attn_mask, and the 1st value via src_key_padding_mask.
that is, we did something like

x = nn.Softmax(dim=-1)
x(torch.tensor([float('-inf'), float('-inf'), float('-inf')]))

which give

tensor([nan, nan, nan])

for the second sentence, no word is masked by src_key_padding_mask, so the output is the same as after applying only attn_mask.
for the third sentence, the 2nd column is all zero.
for the fourth sentence, the 3rd column is all zero.
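in practice you would usually not write the padding mask out by hand. a small sketch that builds the (N, S) boolean mask from the real sentence lengths (the lengths here are just made-up example values); if you already have a tensor of padded token ids, src == pad_idx gives the same thing, and if your mask comes out shaped (S, N) like the one posted above, transpose it with mask.t() first,

import torch

lengths = torch.tensor([3, 2, 3, 1])  # real (unpadded) length of each of the 4 sentences
S = 3                                 # padded source sequence length
src_key_padding_mask = torch.arange(S)[None, :] >= lengths[:, None]  # (N, S), True = padding
print(src_key_padding_mask)

gives

tensor([[False, False, False],
        [False, False,  True],
        [False, False, False],
        [False,  True,  True]])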

but one problem with the transformer is that attention by itself does not consider the order of words, which is why positional encodings are added.
