Hi,
I don't understand how to use src_key_padding_mask to mask out padded positions in the inputs. In the following piece of code (PyTorch 1.7.1), I expected the three calls to tfm to yield the same output. Why is that not the case?
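For context, the convention I am assuming (based on the docs) is that the mask has shape (batch, src_len) and that True marks key positions the attention should ignore. A minimal sketch with nn.MultiheadAttention, just to illustrate that convention (the variable names here are only for illustration):

import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=4, num_heads=1)
q = torch.rand(3, 1, 4)   # (tgt_len, batch, dim)
kv = torch.rand(5, 1, 4)  # (src_len, batch, dim)
# (batch, src_len); True = key position is padding and should be ignored
mask = torch.tensor([[False, False, False, True, True]])
_, weights = attn(q, kv, kv, key_padding_mask=mask)
print(weights)  # the two masked key columns should have zero attention weight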
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
dim_emb = 4
n_head = 1
n_layers = 1
dim_ff = dim_emb*4
dropout = 0.0
activation = 'gelu'
tfm = nn.Transformer(
    d_model=dim_emb,
    nhead=n_head,
    num_encoder_layers=n_layers,
    num_decoder_layers=n_layers,
    dim_feedforward=dim_ff,
    dropout=dropout,
    activation=activation,
    # batch_first=False,  # not available in 1.7.1, hence the transposes below
).eval()
L = 6
bs = 7
src = torch.rand((bs, L, dim_emb))
x = torch.rand((bs, 9, dim_emb))
i = 2  # number of non-padded source positions
# (bs, L) boolean mask; True marks padded positions
padding = torch.tensor([False, False, True, True, True, True]).unsqueeze(0).expand((bs, L))
# padding = torch.tensor([False, False]).unsqueeze(0).expand((bs, i))  # all-False mask for the truncated call
print(src.size(), x.size(), padding.size())
y = tfm(src=src.transpose(0, 1),
        tgt=x.transpose(0, 1),
        src_mask=None,
        src_key_padding_mask=padding,
        ).transpose(0, 1)
print(y[0])
# instead of padding, cut the inputs to length 2;
# since everything from position 2 on is padded, it should be equivalent?
y = tfm(src=src.transpose(0, 1)[:i],
        tgt=x.transpose(0, 1),
        src_mask=None,
        src_key_padding_mask=None,
        ).transpose(0, 1)
print(y[0])
# modify the inputs that are padded; that shouldn't change anything, right?
src2 = src.masked_fill(padding.unsqueeze(2), 0.)
y = tfm(src=src2.transpose(0, 1),
        tgt=x.transpose(0, 1),
        src_mask=None,
        src_key_padding_mask=padding,
        ).transpose(0, 1)
print(y[0])
The output gives:
torch.Size([7, 6, 4]) torch.Size([7, 9, 4]) torch.Size([7, 6])
tensor([[ 1.4638,  0.1487, -1.3234, -0.2891],
        [ 0.2777, -0.8721, -0.9246,  1.5190],
        [-1.3183,  1.1202,  0.8018, -0.6037],
        [ 0.5094, -0.7163, -1.1626,  1.3695],
        [-0.4719,  1.2851,  0.5408, -1.3540],
        [-0.0649, -0.0610,  1.4744, -1.3484],
        [-0.5059,  1.4019,  0.3818, -1.2779],
        [ 1.3614, -0.8594,  0.5496, -1.0516],
        [ 1.0727, -0.7819,  0.9024, -1.1932]], grad_fn=<SelectBackward>)
tensor([[ 1.3663,  0.3374, -1.3847, -0.3191],
        [ 0.3984, -0.6862, -1.1462,  1.4340],
        [-1.1728,  1.3390,  0.5563, -0.7225],
        [ 0.6021, -0.5122, -1.3432,  1.2533],
        [-0.3347,  1.3814,  0.3227, -1.3695],
        [ 0.0232,  0.0239,  1.3902, -1.4374],
        [-0.3574,  1.4819,  0.1602, -1.2847],
        [ 1.4358, -0.8001,  0.4226, -1.0582],
        [ 1.1649, -0.7335,  0.7872, -1.2187]], grad_fn=<SelectBackward>)
tensor([[ 1.3616,  0.3624, -1.3762, -0.3477],
        [ 0.4203, -0.6421, -1.1903,  1.4121],
        [-1.1537,  1.3216,  0.5899, -0.7577],
        [ 0.6180, -0.4711, -1.3745,  1.2275],
        [-0.3194,  1.3595,  0.3485, -1.3886],
        [ 0.0432,  0.0389,  1.3720, -1.4541],
        [-0.3418,  1.4603,  0.1909, -1.3093],
        [ 1.4626, -0.7706,  0.3708, -1.0628],
        [ 1.1972, -0.7074,  0.7416, -1.2314]], grad_fn=<SelectBackward>)
Does anybody know what's going on?
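One thing I noticed while writing this up: nn.Transformer.forward also takes a memory_key_padding_mask argument, which I never pass, so the decoder's cross-attention can presumably still attend to the encoder outputs at the padded positions. Could that be the reason? Below is a sketch of the check I have in mind (reusing tfm, src, x, i and padding from the script above); I would expect the two outputs to match if this is the explanation:

# same as the first call, but also mask the memory in the decoder's cross-attention
y1 = tfm(src=src.transpose(0, 1),
         tgt=x.transpose(0, 1),
         src_key_padding_mask=padding,
         memory_key_padding_mask=padding,
         ).transpose(0, 1)
# same as the second call: truncated source, no masks
y2 = tfm(src=src.transpose(0, 1)[:i],
         tgt=x.transpose(0, 1),
         ).transpose(0, 1)
print(torch.allclose(y1, y2, atol=1e-6))

Either way, any explanation would be much appreciated. Thanks a lot!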