Please consider the following TransformerEncoderLayer, which is used in two ways: once with a full (all-zeros) mask and once with a causal mask:
import torch
from torch.nn import TransformerEncoderLayer

src_length = 6
embedding_size = 12
batch_size = 1
# d_model=12, nhead=2, dim_feedforward=8
transformer_encoder_layer = TransformerEncoderLayer(embedding_size, 2, 8, activation='gelu', norm_first=True)
transformer_encoder_layer.eval()  # eval mode, so dropout is disabled
x = torch.rand((src_length, batch_size, embedding_size))
full_mask = torch.zeros((src_length, src_length))  # all zeros: every position may attend everywhere
transformer_encoder_layer(x, src_mask=full_mask)[:, :, -1]  # last embedding dimension for each position
# tensor([[-0.0359],
# [ 0.3602],
# [ 0.5291],
# [ 0.2098],
# [ 0.4393],
# [ 0.4099]], grad_fn=<SelectBackward0>)
causal_mask = torch.triu(torch.full((src_length, src_length), float('-inf')), diagonal=1)  # -inf above the diagonal blocks attention to future positions
transformer_encoder_layer(x, src_mask=causal_mask)[:, :, -1]
# tensor([[-0.1079],
# [ 0.4843],
# [ 0.6760],
# [ 0.1818],
# [ 0.5266],
# [ 0.4099]], grad_fn=<SelectBackward0>)
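For reference, the causal mask built above looks like this (0 means the position may be attended to, -inf blocks it):

print(causal_mask)
# tensor([[0., -inf, -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf, -inf],
#         [0., 0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0., 0.]])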
As expected, the last sequence position produces the same result in both cases: even under the causal mask, the last position can attend to the whole sequence, while earlier positions see fewer elements and therefore differ.
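As a quick sanity check (added here for illustration), comparing the full output vectors at the last position, rather than just their last dimension, confirms this:

out_full = transformer_encoder_layer(x, src_mask=full_mask)
out_causal = transformer_encoder_layer(x, src_mask=causal_mask)
# for a single layer, the last position attends to identical keys/values under
# both masks, so its entire output vector should match, not just one dimension
print(torch.allclose(out_full[-1], out_causal[-1]))  # expected: True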
However, it gets unexpected when I do the same with a TransformerEncoder:
import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer

src_length = 6
embedding_size = 12
batch_size = 1
transformer_encoder_layer = TransformerEncoderLayer(embedding_size, 2, 8, activation='gelu', norm_first=True)
transformer_net = TransformerEncoder(transformer_encoder_layer, 6)  # stack of 6 layers
transformer_net.eval()
x = torch.rand((src_length, batch_size, embedding_size))
full_mask = torch.zeros((src_length, src_length))
transformer_net(x, mask=full_mask)[:, :, -1]
# tensor([[1.1491],
# [1.3810],
# [1.4263],
# [1.8420],
# [1.8647],
# [1.9420]], grad_fn=<SelectBackward0>)
causal_mask = torch.triu(torch.full((src_length, src_length), float('-inf')), diagonal=1)
transformer_net(x, mask=causal_mask)[:, :, -1]
# tensor([[-0.6918],
# [-0.2881],
# [ 0.0460],
# [ 0.1590],
# [-0.4355],
# [-0.2541]], grad_fn=<SelectBackward0>)
where you can see that I get totally different results, even for the last sequence position. What is the reason for that, and is there a way to avoid it?
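In case it helps, my understanding is that TransformerEncoder(layer, 6) deep-copies the given layer 6 times, so I would expect it to be roughly equivalent to the following manual loop (a minimal sketch under that assumption, which should show the same behavior):

import copy

# assumption: TransformerEncoder(layer, 6) behaves like six deep copies of the
# same (identically initialized) layer applied one after another
layers = [copy.deepcopy(transformer_encoder_layer).eval() for _ in range(6)]
out = x
for layer in layers:
    out = layer(out, src_mask=causal_mask)
print(out[:, :, -1])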