I am experimenting with `nn.MultiheadAttention` and tried the code below, but I couldn't get the `attn_mask` argument to behave the way I expected.

```python
import torch
import torch.nn as nn

tgt = torch.rand(2, 3, 4)
src = torch.rand(2, 3, 4)
src[:, 0, :] = torch.nan  # make source position 0 all-nan
print(f"tgt: {tgt.shape}, {tgt}")
print(f"src: {src.shape}, {src}")

multihead_attn = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)

print("First example")

# mask out key position 0 (the nan position) for every query
mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[:, :, 0] = True
print(f"mask: {mask.shape}, {mask.dtype}, {mask}")

out, attn_weights = multihead_attn(tgt, src, src, attn_mask=mask)
print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {out.shape}, {out}")

print("Second example")

# mask out every key for query position 0 instead
mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[:, 0, :] = True
print(f"mask: {mask.shape}, {mask.dtype}, {mask}")

out, attn_weights = multihead_attn(tgt, src, src, attn_mask=mask)
print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {out.shape}, {out}")
```

Output:

```
tgt: torch.Size([2, 3, 4]), tensor([[[0.5204, 0.8147, 0.7542, 0.5528],
[0.9538, 0.2621, 0.3353, 0.3032],
[0.3012, 0.1474, 0.5201, 0.7111]],

[[0.4666, 0.8606, 0.8641, 0.0374],
[0.3397, 0.3314, 0.7925, 0.1301],
[0.5559, 0.0732, 0.4313, 0.4095]]])
src: torch.Size([2, 3, 4]), tensor([[[   nan,    nan,    nan,    nan],
[0.6643, 0.2720, 0.9359, 0.5026],
[0.5256, 0.6496, 0.8701, 0.9192]],

[[   nan,    nan,    nan,    nan],
[0.9559, 0.8751, 0.0662, 0.0216],
[0.2466, 0.2948, 0.5645, 0.5404]]])
First example
mask: torch.Size([2, 3, 3]), torch.bool, tensor([[[ True, False, False],
[ True, False, False],
[ True, False, False]],

[[ True, False, False],
[ True, False, False],
[ True, False, False]]])
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],

[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]])
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]],

[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]]])

Second example
mask: torch.Size([2, 3, 3]), torch.bool, tensor([[[ True,  True,  True],
[False, False, False],
[False, False, False]],

[[ True,  True,  True],
[False, False, False],
[False, False, False]]])
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],

[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]])
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]],

[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]]])
```

If I understood the `MultiheadAttention` code here correctly, masked positions should be filled in before the softmax is computed, so the `nan` values in `src` should have no effect. Why are they not masked out by `MultiheadAttention` in either example?
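
To make my expectation concrete, here is a minimal sketch of what I thought the masking would do, assuming masked positions are overwritten (e.g. with `-inf`) before the softmax, so that a `nan` sitting in a masked position is harmless:

```python
import torch

# one query over three keys; the score for key 1 is nan
scores = torch.tensor([[0.1, float("nan"), 0.7]])
mask = torch.tensor([[False, True, False]])  # True = masked

# overwrite masked positions *before* the softmax...
masked_scores = scores.masked_fill(mask, float("-inf"))

# ...so the nan never reaches the softmax and the weights stay clean
print(torch.softmax(masked_scores, dim=-1))  # tensor([[0.3543, 0.0000, 0.6457]])
```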

Be aware that, per the documentation: “For a binary mask, a `True` value indicates that the corresponding position is not allowed to attend.” Your masks are mostly `False`, so most positions are not masked at all.
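
For instance, the following sketch (assuming an `embed_dim=4`, `num_heads=1`, `batch_first=True` module to match your shapes) blocks every query from attending to key position 0, and the returned weights show a zero in that column:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

mha = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)
x = torch.rand(2, 3, 4)

# True = "not allowed to attend": forbid every query from seeing key 0
mask = torch.zeros(2, 3, 3, dtype=torch.bool)  # (N * num_heads, L, S)
mask[:, :, 0] = True

out, weights = mha(x, x, x, attn_mask=mask)
print(weights)  # column 0 is 0.0 in every row; each row still sums to 1
```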

I built PyTorch from source to debug what's going on. When the input tensor contains `nan` values, the scores produced by `torch.baddbmm` here contain `nan`, which makes the softmax after it produce `nan` everywhere. Setting the `nan` values to `0` before feeding the input to multi-head attention works; the sketch below isolates the propagation, and the full example follows it. I'm not sure whether this is intended behaviour or a bug.
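
Here is a minimal sketch of the propagation, assuming the float-mask path where the `attn_mask` (with `-inf` at masked positions) is passed as the `input` argument of `torch.baddbmm`: since `-inf + nan` is still `nan`, the mask cannot cancel a `nan` score, and the softmax row then becomes `nan` everywhere:

```python
import torch

q = torch.rand(1, 2, 4)        # (batch, L, E): clean queries
k = torch.rand(1, 3, 4)        # (batch, S, E): keys, with key 0 all-nan
k[:, 0, :] = float("nan")

# float mask as used internally: -inf at masked positions, 0 elsewhere
mask = torch.zeros(1, 2, 3)
mask[:, :, 0] = float("-inf")  # mask out key 0 for every query

# roughly what the attention code does: add the mask to q @ k^T
attn = torch.baddbmm(mask, q, k.transpose(-2, -1))
print(attn)  # column 0 is nan: -inf + nan is still nan

# every row contains a nan, so the softmax is nan everywhere
print(torch.softmax(attn, dim=-1))
```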

```python
import torch
import torch.nn as nn

torch.manual_seed(42)

tgt = torch.rand(2, 3, 4)
src = torch.rand(2, 3, 4)
src[:, 0, :] = torch.nan  # make source position 0 all-nan
print(f"tgt: {tgt.shape}, {tgt}")
print(f"src: {src.shape}, {src}")

# mask out key position 0 (the nan position) for every query
mask = torch.zeros(2, 3, 3, dtype=torch.bool)
mask[:, :, 0] = True
print(f"mask: {mask.shape}, {mask.dtype}, {mask}")

multihead_attn = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)

print("First example")

# src still contains nan: everything comes out nan despite the mask
out, attn_weights = multihead_attn(tgt, src, src, attn_mask=mask)
print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {out.shape}, {out}")

print("Second example")

# replace the nan values with 0 before the attention call
src_replace_nan = torch.where(torch.isnan(src), 0, src)
print(f"src_replace_nan: {src_replace_nan.shape}, {src_replace_nan}")

out, attn_weights = multihead_attn(tgt, src_replace_nan, src_replace_nan, attn_mask=mask)
print(f"multihead_attn attn_weights: {attn_weights.shape}, {attn_weights}")
print(f"multihead_attn out: {out.shape}, {out}")
```

Output:

```
tgt: torch.Size([2, 3, 4]), tensor([[[0.8823, 0.9150, 0.3829, 0.9593],
[0.3904, 0.6009, 0.2566, 0.7936],
[0.9408, 0.1332, 0.9346, 0.5936]],

[[0.8694, 0.5677, 0.7411, 0.4294],
[0.8854, 0.5739, 0.2666, 0.6274],
[0.2696, 0.4414, 0.2969, 0.8317]]])
src: torch.Size([2, 3, 4]), tensor([[[   nan,    nan,    nan,    nan],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090]],

[[   nan,    nan,    nan,    nan],
[0.6343, 0.3644, 0.7104, 0.9464],
[0.7890, 0.2814, 0.7886, 0.5895]]])
mask: torch.Size([2, 3, 3]), torch.bool, tensor([[[ True, False, False],
[ True, False, False],
[ True, False, False]],

[[ True, False, False],
[ True, False, False],
[ True, False, False]]])
First example
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]],

[[nan, nan, nan],
[nan, nan, nan],
[nan, nan, nan]]])
multihead_attn out: torch.Size([2, 3, 4]), tensor([[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]],

[[nan, nan, nan, nan],
[nan, nan, nan, nan],
[nan, nan, nan, nan]]])
Second example
src_replace_nan: torch.Size([2, 3, 4]), tensor([[[0.0000, 0.0000, 0.0000, 0.0000],
[0.5472, 0.0062, 0.9516, 0.0753],
[0.8860, 0.5832, 0.3376, 0.8090]],

[[0.0000, 0.0000, 0.0000, 0.0000],
[0.6343, 0.3644, 0.7104, 0.9464],
[0.7890, 0.2814, 0.7886, 0.5895]]])
multihead_attn attn_weights: torch.Size([2, 3, 3]), tensor([[[0.0000, 0.3902, 0.6098],
[0.0000, 0.4472, 0.5528],
[0.0000, 0.4268, 0.5732]],

[[0.0000, 0.5090, 0.4910],
[0.0000, 0.5182, 0.4818],
```