Hello everybody,
I tried searching around for a similar issue but could not find anything.
I am having trouble when using an anti-causal attention mask as attn_mask (the transpose of the causal one, so that each position attends only to itself and the future elements) combined with the key_padding_mask. In the output, the padded embeddings take a nan value. I don't understand what I am doing wrong. Any suggestions?
Here is the code:
from torch.nn import MultiheadAttention
import torch

torch.manual_seed(42)
B, seq_len, embed_dim = 1, 5, 5
mha = MultiheadAttention(embed_dim=embed_dim, num_heads=1, batch_first=True)
x = torch.rand((B, seq_len, embed_dim))

# anti-causal mask: -inf strictly below the diagonal, 0 on and above it
mask = torch.tril(torch.ones(x.size(1), x.size(1)) * float("-inf"), diagonal=-1)
# float key padding mask: the last key of the sequence is padded out (-inf)
key_padding_mask = torch.tensor([[0, 0, 0, 0, float("-inf")]])

res = mha(x, x, x, attn_mask=mask, key_padding_mask=key_padding_mask)
print(key_padding_mask)
print(mask)
print(res[0])
tensor([[0., 0., 0., 0., -inf]])
tensor([[0., 0., 0., 0., 0.],
[-inf, 0., 0., 0., 0.],
[-inf, -inf, 0., 0., 0.],
[-inf, -inf, -inf, 0., 0.],
[-inf, -inf, -inf, -inf, 0.]])
tensor([[[-0.1386, 0.0392, -0.0841, -0.0103, 0.0435],
[-0.1620, 0.0555, -0.1101, -0.0131, 0.0673],
[-0.1162, 0.0409, -0.0969, -0.0043, 0.0651],
[-0.1792, 0.0674, -0.0597, -0.0299, 0.0199],
[ nan, nan, nan, nan, nan]]],
grad_fn=<TransposeBackward0>)
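In case it helps to see where I think the nan comes from, here is a rough sketch of how I assume the two float masks get merged inside MultiheadAttention (the broadcast-and-add is my assumption, not something I checked in the source). With the anti-causal mask the last query can only attend to the last key, which is exactly the one removed by the padding mask, so its score row ends up entirely -inf and the softmax over that row returns nan:

import torch
import torch.nn.functional as F

seq_len = 5

# same masks as above
attn_mask = torch.tril(torch.ones(seq_len, seq_len) * float("-inf"), diagonal=-1)
key_padding_mask = torch.tensor([0.0, 0.0, 0.0, 0.0, float("-inf")])

# my assumption of the merge: the padding mask is broadcast over the query
# dimension and added to attn_mask, giving one additive mask per (query, key)
merged = attn_mask + key_padding_mask
print(merged)  # the last row (last query) is entirely -inf

# a score row that is all -inf becomes nan after softmax (exp sums to 0, so 0/0)
scores = torch.rand(seq_len, seq_len) + merged
print(F.softmax(scores, dim=-1))  # last row is all nan, matching the output above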