I am experimenting with `nn.MultiheadAttention` and tried the code below, but I couldn't get the `attn_mask` argument to behave the way I expected.

```python
import torch
import torch.nn as nn

tgt = torch.rand(2, 3, 4)
src = torch.rand(2, 3, 4)
src[:, 0, :] = torch.nan  # make source position 0 all-nan
print(f"tgt: {tgt.shape}, {tgt}")
print(f"src: {src.shape}, {src}")

multihead_attn = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)

print("First example")

# mask out key position 0 (the nan position) for every query
mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[:, :, 0] = True
print(f"mask: {mask.shape}, {mask.dtype}, {mask}")

out, attn_weights = multihead_attn(tgt, src, src, attn_mask=mask)
print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {out.shape}, {out}")

print("Second example")

# mask out every key for query position 0 instead
mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[:, 0, :] = True
print(f"mask: {mask.shape}, {mask.dtype}, {mask}")

out, attn_weights = multihead_attn(tgt, src, src, attn_mask=mask)
print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {out.shape}, {out}")
```

Output:

```
tgt: torch.Size([2, 3, 4]), tensor([[[0.5204, 0.8147, 0.7542, 0.5528],
[0.9538, 0.2621, 0.3353, 0.3032],
[0.3012, 0.1474, 0.5201, 0.7111]],

[[0.4666, 0.8606, 0.8641, 0.0374],
[0.3397, 0.3314, 0.7925, 0.1301],
[0.5559, 0.0732, 0.4313, 0.4095]]])
src: torch.Size([2, 3, 4]), tensor([[[   nan,    nan,    nan,    nan],
[0.6643, 0.2720, 0.9359, 0.5026],
[0.5256, 0.6496, 0.8701, 0.9192]],

[[   nan,    nan,    nan,    nan],
[0.9559, 0.8751, 0.0662, 0.0216],
[0.2466, 0.2948, 0.5645, 0.5404]]])
First example
mask: torch.Size([2, 3, 3]), torch.bool, tensor([[[ True, False, False],
[ True, False, False],
[ True, False, False]],

[[ True, False, False],
[ True, False, False],
[ True, False, False]]])
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],

[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]])
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]],

[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]]])

Second example
mask: torch.Size([2, 3, 3]), torch.bool, tensor([[[ True,  True,  True],
[False, False, False],
[False, False, False]],

[[ True,  True,  True],
[False, False, False],
[False, False, False]]])
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],

[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]])
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]],

[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]]])
```

If I understood the `MultiheadAttention` code here correctly, masked positions should be filled in before the softmax is computed, so the `nan` values in `src` should have no effect. Why are they not masked out by `MultiheadAttention` in either example?
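
To make my expectation concrete, here is a minimal sketch of what I thought the masking would do, assuming masked positions are overwritten (e.g. with `-inf`) before the softmax, so that a `nan` sitting in a masked position is harmless:

```python
import torch

# one query over three keys; the score for key 1 is nan
scores = torch.tensor([[0.1, float("nan"), 0.7]])
mask = torch.tensor([[False, True, False]])  # True = masked

# overwrite masked positions *before* the softmax...
masked_scores = scores.masked_fill(mask, float("-inf"))

# ...so the nan never reaches the softmax and the weights stay clean
print(torch.softmax(masked_scores, dim=-1))  # tensor([[0.3543, 0.0000, 0.6457]])
```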

Be aware that, per the documentation: “For a binary mask, a `True` value indicates that the corresponding position is not allowed to attend.” Your masks are mostly `False`, so most positions are not masked at all.
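
For instance, the following sketch (assuming an `embed_dim=4`, `num_heads=1`, `batch_first=True` module to match your shapes) blocks every query from attending to key position 0, and the returned weights show a zero in that column:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

mha = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)
x = torch.rand(2, 3, 4)

# True = "not allowed to attend": forbid every query from seeing key 0
mask = torch.zeros(2, 3, 3, dtype=torch.bool)  # (N * num_heads, L, S)
mask[:, :, 0] = True

out, weights = mha(x, x, x, attn_mask=mask)
print(weights)  # column 0 is 0.0 in every row; each row still sums to 1
```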

I built PyTorch from source to debug what's going on. When the input tensor contains `nan` values, the scores produced by `torch.baddbmm` here contain `nan`, which makes the softmax after it produce `nan` everywhere. Setting the `nan` values to `0` before feeding the input to multi-head attention works; the sketch below isolates the propagation, and the full example follows it. I'm not sure whether this is intended behaviour or a bug.
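
Here is a minimal sketch of the propagation, assuming the float-mask path where the `attn_mask` (with `-inf` at masked positions) is passed as the `input` argument of `torch.baddbmm`: since `-inf + nan` is still `nan`, the mask cannot cancel a `nan` score, and the softmax row then becomes `nan` everywhere:

```python
import torch

q = torch.rand(1, 2, 4)        # (batch, L, E): clean queries
k = torch.rand(1, 3, 4)        # (batch, S, E): keys, with key 0 all-nan
k[:, 0, :] = float("nan")

# float mask as used internally: -inf at masked positions, 0 elsewhere
mask = torch.zeros(1, 2, 3)
mask[:, :, 0] = float("-inf")  # mask out key 0 for every query

# roughly what the attention code does: add the mask to q @ k^T
attn = torch.baddbmm(mask, q, k.transpose(-2, -1))
print(attn)  # column 0 is nan: -inf + nan is still nan

# every row contains a nan, so the softmax is nan everywhere
print(torch.softmax(attn, dim=-1))
```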

```python
import torch
import torch.nn as nn

torch.manual_seed(42)

tgt = torch.rand(2, 3, 4)
src = torch.rand(2, 3, 4)
src[:, 0, :] = torch.nan  # make source position 0 all-nan
print(f"tgt: {tgt.shape}, {tgt}")
print(f"src: {src.shape}, {src}")

# mask out key position 0 (the nan position) for every query
mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[:, :, 0] = True
print(f"mask: {mask.shape}, {mask.dtype}, {mask}")

multihead_attn = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)

print("First example")

# src still contains nan: everything comes out nan despite the mask
out, attn_weights = multihead_attn(tgt, src, src, attn_mask=mask)
print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {out.shape}, {out}")

print("Second example")

# replace the nan values with 0 before the attention call
src_replace_nan = torch.where(torch.isnan(src), 0, src)
print(f"src_replace_nan: {src_replace_nan.shape}, {src_replace_nan}")

out, attn_weights = multihead_attn(tgt, src_replace_nan, src_replace_nan, attn_mask=mask)
print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {out.shape}, {out}")
```

Output:

```
tgt: torch.Size([2, 3, 4]), tensor([[[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936]],

[[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317]]])
src: torch.Size([2, 3, 4]), tensor([[[   nan,    nan,    nan,    nan],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090]],

[[   nan,    nan,    nan,    nan],
[0.6343, 0.3644, 0.7104, 0.9464],
[0.7890, 0.2814, 0.7886, 0.5895]]])
mask: torch.Size([2, 3, 3]), torch.bool, tensor([[[ True, False, False],
[ True, False, False],
[ True, False, False]],

[[ True, False, False],
[ True, False, False],
[ True, False, False]]])
First example
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],

[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]])
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]],

[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]]])
Second example
src_replace_nan: torch.Size([2, 3, 4]), tensor([[[0.0000, 0.0000, 0.0000, 0.0000],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090]],

[[0.0000, 0.0000, 0.0000, 0.0000],
[0.6343, 0.3644, 0.7104, 0.9464],
[0.7890, 0.2814, 0.7886, 0.5895]]])
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[0.0000, 0.3902, 0.6098],
[0.0000, 0.4472, 0.5528],
[0.0000, 0.4268, 0.5732]],

[[0.0000, 0.5090, 0.4910],
[0.0000, 0.5182, 0.4818],
```