MultiHeadAttention attn_mask argument doesn't seem to mask nan values

I am experimenting with nn.MultiheadAttention and tried the code below, but I couldn't get the attn_mask argument to work the way I expected.

import torch
import torch.nn as nn


tgt = torch.rand(2, 3, 4)
src = torch.rand(2, 3, 4)
src[:, 0, :] = torch.nan  # the first source position in every batch is nan
print(f"tgt: {tgt.shape}, {tgt}")
print(f"src: {src.shape}, {src}")

print("First example")

mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[:, :, 0] = True  # block every query from attending to key position 0 (the nan row of src)

print(f"mask: {mask.shape}, {mask.dtype}, {mask}")

multihead_attn = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)
attn_out, attn_weights = multihead_attn(tgt, src, src, attn_mask=mask)

print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {attn_out.shape}, {attn_out}")

print("Second example")

mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[:, 0, :] = True  # block query position 0 from attending to any key

print(f"mask: {mask.shape}, {mask.dtype}, {mask}")

multihead_attn = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)
attn_out, attn_weights = multihead_attn(tgt, src, src, attn_mask=mask)

print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {attn_out.shape}, {attn_out}")

Output:

tgt: torch.Size([2, 3, 4]), tensor([[[0.5204, 0.8147, 0.7542, 0.5528],
         [0.9538, 0.2621, 0.3353, 0.3032],
         [0.3012, 0.1474, 0.5201, 0.7111]],

        [[0.4666, 0.8606, 0.8641, 0.0374],
         [0.3397, 0.3314, 0.7925, 0.1301],
         [0.5559, 0.0732, 0.4313, 0.4095]]])
src: torch.Size([2, 3, 4]), tensor([[[   nan,    nan,    nan,    nan],
         [0.6643, 0.2720, 0.9359, 0.5026],
         [0.5256, 0.6496, 0.8701, 0.9192]],

        [[   nan,    nan,    nan,    nan],
         [0.9559, 0.8751, 0.0662, 0.0216],
         [0.2466, 0.2948, 0.5645, 0.5404]]])
First example
mask: torch.Size([2, 3, 3]), torch.bool, tensor([[[ True, False, False],
         [ True, False, False],
         [ True, False, False]],

        [[ True, False, False],
         [ True, False, False],
         [ True, False, False]]])
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]],

        [[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]]], grad_fn=<MeanBackward1>)
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]]], grad_fn=<TransposeBackward0>)

Second example
mask: torch.Size([2, 3, 3]), torch.bool, tensor([[[ True,  True,  True],
         [False, False, False],
         [False, False, False]],

        [[ True,  True,  True],
         [False, False, False],
         [False, False, False]]])
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]],

        [[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]]], grad_fn=<MeanBackward1>)
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]]], grad_fn=<TransposeBackward0>)

If I understood the MultiheadAttention code here correctly, the masked positions, and therefore the nan values, should be excluded before the softmax is computed. I am not sure why the nan values in src are not masked out in either of my examples.
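To illustrate what I expected: if masked positions were overwritten (for example with masked_fill) before the softmax, a nan at a masked position would not matter. This is only a sketch of my expectation, not a claim about what the implementation actually does:

import torch

scores = torch.tensor([[float("nan"), 0.3, 0.7]])  # q @ k^T where key 0 is nan
mask = torch.tensor([[True, False, False]])        # True = do not attend to key 0
filled = scores.masked_fill(mask, float("-inf"))   # the nan is overwritten by -inf
print(torch.softmax(filled, dim=-1))               # tensor([[0.0000, 0.4013, 0.5987]])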

Be aware that: “For a binary mask, a True value indicates that the corresponding position is not allowed to attend.” Your masks are mostly False, which means most positions are not masked.
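With finite inputs, the mask from your first example does behave as documented: the attention weights for the blocked key column go to zero. A quick sketch (the module weights are random, so exact numbers will vary):

import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)
q = torch.rand(1, 3, 4)
kv = torch.rand(1, 3, 4)                       # finite values, no nan
mask = torch.zeros(1, 3, 3, dtype=torch.bool)
mask[:, :, 0] = True                           # True = this key may not be attended to
_, weights = mha(q, kv, kv, attn_mask=mask)
print(weights)                                 # column 0 of every row is exactly 0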

I installed PyTorch from source to debug what's going on. It looks like when there are nan values in the input, the torch.baddbmm call here produces nan attention scores, which makes the softmax produce nan everywhere after that. Setting the nan values to 0 before feeding the tensors into the multi-head attention seems to work; see the example below (and the minimal sketch after its output). I'm not sure whether this is intended behaviour or a bug.

import torch
import torch.nn as nn

torch.manual_seed(42)

tgt = torch.rand(2, 3, 4)
src = torch.rand(2, 3, 4)
src[:, 0, :] = torch.nan  # the first source position in every batch is nan
print(f"tgt: {tgt.shape}, {tgt}")
print(f"src: {src.shape}, {src}")

mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[:, :, 0] = True  # block every query from attending to key position 0 (the nan row of src)
print(f"mask: {mask.shape}, {mask.dtype}, {mask}")

print("First example")

multihead_attn = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)
attn_out, attn_weights = multihead_attn(tgt, src, src, attn_mask=mask)

print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {attn_out.shape}, {attn_out}")

print("Second example")

src_replace_nan = torch.where(torch.isnan(src), 0, src)  # workaround: replace nan with 0 before attention
print(f"src_replace_nan: {src_replace_nan.shape}, {src_replace_nan}")

multihead_attn = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)
attn_out, attn_weights = multihead_attn(tgt, src_replace_nan, src_replace_nan, attn_mask=mask)

print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {attn_out.shape}, {attn_out}")

Output:

tgt: torch.Size([2, 3, 4]), tensor([[[0.8823, 0.9150, 0.3829, 0.9593],
         [0.3904, 0.6009, 0.2566, 0.7936],
         [0.9408, 0.1332, 0.9346, 0.5936]],

        [[0.8694, 0.5677, 0.7411, 0.4294],
         [0.8854, 0.5739, 0.2666, 0.6274],
         [0.2696, 0.4414, 0.2969, 0.8317]]])
src: torch.Size([2, 3, 4]), tensor([[[   nan,    nan,    nan,    nan],
         [0.5472, 0.0062, 0.9516, 0.0753],
         [0.8860, 0.5832, 0.3376, 0.8090]],

        [[   nan,    nan,    nan,    nan],
         [0.6343, 0.3644, 0.7104, 0.9464],
         [0.7890, 0.2814, 0.7886, 0.5895]]])
mask: torch.Size([2, 3, 3]), torch.bool, tensor([[[ True, False, False],
         [ True, False, False],
         [ True, False, False]],

        [[ True, False, False],
         [ True, False, False],
         [ True, False, False]]])
First example
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]],

        [[nan, nan, nan],
         [nan, nan, nan],
         [nan, nan, nan]]], grad_fn=<MeanBackward1>)
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]]], grad_fn=<TransposeBackward0>)
Second example
src_replace_nan: torch.Size([2, 3, 4]), tensor([[[0.0000, 0.0000, 0.0000, 0.0000],
         [0.5472, 0.0062, 0.9516, 0.0753],
         [0.8860, 0.5832, 0.3376, 0.8090]],

        [[0.0000, 0.0000, 0.0000, 0.0000],
         [0.6343, 0.3644, 0.7104, 0.9464],
         [0.7890, 0.2814, 0.7886, 0.5895]]])
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[0.0000, 0.3902, 0.6098],
         [0.0000, 0.4472, 0.5528],
         [0.0000, 0.4268, 0.5732]],

        [[0.0000, 0.5090, 0.4910],
         [0.0000, 0.5182, 0.4818],
         [0.0000, 0.4987, 0.5013]]], grad_fn=<MeanBackward1>)
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[ 0.0439, -0.0275,  0.0880, -0.1783],
         [ 0.0369, -0.0280,  0.0826, -0.1601],
         [ 0.0394, -0.0278,  0.0845, -0.1666]],

        [[ 0.1194,  0.0144,  0.1713, -0.2186],
         [ 0.1204,  0.0149,  0.1723, -0.2194],
         [ 0.1184,  0.0139,  0.1702, -0.2178]]], grad_fn=<TransposeBackward0>)
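If I'm reading the code right, the bool attn_mask is converted to an additive bias (0 where attention is allowed, -inf where it is blocked) and passed as the first argument of torch.baddbmm, so the mask is added to the scores rather than overwriting them. Since nan + (-inf) is still nan, a masked-but-nan score survives and then poisons its entire softmax row. A minimal sketch of that arithmetic:

import torch

scores = torch.tensor([[float("nan"), 0.3, 0.7]])  # q @ k^T where key 0 is nan
bias = torch.tensor([[float("-inf"), 0.0, 0.0]])   # bool mask converted to an additive -inf bias
masked = scores + bias                             # nan + (-inf) = nan, so the mask cannot remove it
print(masked)                                      # tensor([[nan, 0.3000, 0.7000]])
print(torch.softmax(masked, dim=-1))               # tensor([[nan, nan, nan]])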