Transformer src_key_padding_mask

Hi,

I don’t understand how to use src_key_padding_mask to handle padded inputs. In the following piece of code (PyTorch 1.7.1), I expected the three calls to tfm to yield the same output. Why isn’t that the case?

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dim_emb = 4
n_head = 1
n_layers = 1
dim_ff = dim_emb*4
dropout = 0.0
activation = 'gelu'

tfm = nn.Transformer(
    d_model=dim_emb,
    nhead=n_head,
    num_encoder_layers=n_layers,
    num_decoder_layers=n_layers,
    dim_feedforward=dim_ff,
    dropout=dropout,
    activation=activation,
    #  batch_first=False,
).eval()

L = 6
bs = 7
src = torch.rand((bs, L, dim_emb))
x = torch.rand((bs, 9, dim_emb))
i = 2  # number of non-padded positions per source sequence
padding = torch.tensor([False, False, True, True, True, True]).unsqueeze(0).expand((bs, L))  # True = padded position to ignore
#  padding = torch.tensor([False, False]).unsqueeze(0).expand((bs, L))
print(src.size(), x.size(), padding.size())
y = tfm(src=src.transpose(0, 1),
        tgt=x.transpose(0, 1),
        src_mask=None,
        src_key_padding_mask=padding,
).transpose(0, 1)
print(y[0])

# instead of padding, truncate the source to length i = 2;
# since everything after position 2 is padded, this should be equivalent, right?
y = tfm(src=src.transpose(0, 1)[:i],
        tgt=x.transpose(0, 1),
        src_mask=None,
        src_key_padding_mask=None,
).transpose(0, 1)
print(y[0])

# modify the values at the padded positions; this shouldn't change anything, right?
src2 = src.masked_fill(padding.unsqueeze(2), 0.)
y = tfm(src=src2.transpose(0, 1),
        tgt=x.transpose(0, 1),
        src_mask=None,
        src_key_padding_mask=padding,
).transpose(0, 1)
print(y[0])

The output gives:

torch.Size([7, 6, 4]) torch.Size([7, 9, 4]) torch.Size([7, 6])
tensor([[ 1.4638,  0.1487, -1.3234, -0.2891],
        [ 0.2777, -0.8721, -0.9246,  1.5190],
        [-1.3183,  1.1202,  0.8018, -0.6037],
        [ 0.5094, -0.7163, -1.1626,  1.3695],
        [-0.4719,  1.2851,  0.5408, -1.3540],
        [-0.0649, -0.0610,  1.4744, -1.3484],
        [-0.5059,  1.4019,  0.3818, -1.2779],
        [ 1.3614, -0.8594,  0.5496, -1.0516],
        [ 1.0727, -0.7819,  0.9024, -1.1932]], grad_fn=<SelectBackward>)
tensor([[ 1.3663,  0.3374, -1.3847, -0.3191],
        [ 0.3984, -0.6862, -1.1462,  1.4340],
        [-1.1728,  1.3390,  0.5563, -0.7225],
        [ 0.6021, -0.5122, -1.3432,  1.2533],
        [-0.3347,  1.3814,  0.3227, -1.3695],
        [ 0.0232,  0.0239,  1.3902, -1.4374],
        [-0.3574,  1.4819,  0.1602, -1.2847],
        [ 1.4358, -0.8001,  0.4226, -1.0582],
        [ 1.1649, -0.7335,  0.7872, -1.2187]], grad_fn=<SelectBackward>)
tensor([[ 1.3616,  0.3624, -1.3762, -0.3477],
        [ 0.4203, -0.6421, -1.1903,  1.4121],
        [-1.1537,  1.3216,  0.5899, -0.7577],
        [ 0.6180, -0.4711, -1.3745,  1.2275],
        [-0.3194,  1.3595,  0.3485, -1.3886],
        [ 0.0432,  0.0389,  1.3720, -1.4541],
        [-0.3418,  1.4603,  0.1909, -1.3093],
        [ 1.4626, -0.7706,  0.3708, -1.0628],
        [ 1.1972, -0.7074,  0.7416, -1.2314]], grad_fn=<SelectBackward>)

Does anybody know what’s going on? Thanks a lot!

It now works as intended by also passing memory_key_padding_mask=padding to the tfm call: src_key_padding_mask only masks the padded positions in the encoder's self-attention, while the decoder's cross-attention over the encoder output ignores those positions only if memory_key_padding_mask is supplied as well.
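
For reference, a minimal sketch of the corrected call, reusing the tfm, src, x, and padding defined above; with the extra mask the output agrees with the truncated-input call:

y = tfm(src=src.transpose(0, 1),
        tgt=x.transpose(0, 1),
        src_key_padding_mask=padding,
        memory_key_padding_mask=padding,  # mask the same positions in the decoder's attention over the encoder output
).transpose(0, 1)
print(y[0])  # now matches the output of the length-2 (truncated) call above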