I am trying to use nn.MultiheadAttention with a float attention mask on PyTorch 2.0.1, as in the demo code below:
import torch
import torch.nn as nn
my_attn = nn.MultiheadAttention(
    embed_dim=256,
    batch_first=True,
    num_heads=8
)
my_attn = my_attn.to("cuda")
x = torch.randn((1, 10, 256), dtype=torch.float, device="cuda")
# True entries mark positions that are NOT allowed to attend;
# here only the two diagonal blocks may attend to each other.
bool_mask = torch.ones((10, 10), dtype=torch.bool, device="cuda")
bool_mask[:2, :2] = False
bool_mask[2:, 2:] = False
# A float mask is meant to be additive: its values are added to the attention scores.
float_mask = torch.randn((10, 10), dtype=torch.float, device="cuda")
my_attn.eval()
with torch.no_grad():
    bool_mask_res, bool_mask_scores = my_attn(
        query=x, key=x, value=x, attn_mask=bool_mask
    )
    print(bool_mask_scores)
    float_mask_res, float_mask_scores = my_attn(
        query=x, key=x, value=x, attn_mask=float_mask
    )
    print(float_mask_scores)  # NOTE: the bug is here, the scores will be all 'nan'.
The result is that float_mask_scores is filled with nan, which is obviously wrong. I also get a UserWarning:
torch/nn/modules/activation.py:1160: UserWarning: Converting mask without torch.bool dtype to bool; this will negatively affect performance. Prefer to use a boolean mask directly. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343995026/work/aten/src/ATen/native/transformers/attention.cpp:150.)
return torch._native_multi_head_attention(
It seems that float_mask is being treated as a bool mask inside the built-in code.
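Here is my understanding of why that would produce nan (a minimal sketch on CPU, not the actual library internals): if a random float mask is reinterpreted as boolean, essentially every entry becomes True, i.e. "do not attend", so every row of the scores is fully masked and softmax over a row of -inf gives nan.

import torch

scores = torch.randn(10, 10)
float_mask = torch.randn(10, 10)
as_bool = float_mask != 0  # what a bool reinterpretation would do: nonzero -> True
fully_masked = scores.masked_fill(as_bool, float("-inf"))  # every position disallowed
print(torch.softmax(fully_masked, dim=-1))  # rows of all -inf come out as nan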
I guess this is caused by scaled_dot_product_attention(), because the same code runs correctly on PyTorch 1.13.1.
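For comparison, a float attn_mask passed directly to F.scaled_dot_product_attention is documented to be added to the attention scores before softmax, so I would expect a sketch like the following (shapes chosen to match my example: batch 1, 8 heads, sequence length 10, head_dim 32) to return finite values rather than nan:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 10, 32, device="cuda")  # (batch, heads, seq_len, head_dim)
float_mask = torch.randn(10, 10, device="cuda")       # additive mask, broadcast over batch/heads

out = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
print(torch.isnan(out).any())  # I would expect this to be False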