nn.MultiheadAttention with float mask results in 'nan' for all attention scores

I'm trying to use nn.MultiheadAttention with a float attention mask on PyTorch 2.0.1, as in the demo code below:

import torch
import torch.nn as nn

my_attn = nn.MultiheadAttention(
    embed_dim=256, num_heads=8, batch_first=True  # num_heads assumed; embed_dim/batch_first match x below
)
my_attn = my_attn.to("cuda")

x = torch.randn((1, 10, 256), dtype=torch.float, device="cuda")
bool_mask = torch.ones((10, 10), dtype=torch.bool, device="cuda")
bool_mask[:2, :2] = False
bool_mask[2:, 2:] = False

float_mask = torch.randn((10, 10), dtype=torch.float, device="cuda")

with torch.no_grad():
    bool_mask_res, bool_mask_scores = my_attn(
        query=x, key=x, value=x, attn_mask=bool_mask
    )
    float_mask_res, float_mask_scores = my_attn(
        query=x, key=x, value=x, attn_mask=float_mask
    )
    print(float_mask_scores)  # NOTE: the bug is here, the scores will be all 'nan'.

The result is that float_mask_scores is filled entirely with nan, which is obviously wrong. I also get a UserWarning:

torch/nn/modules/activation.py:1160: UserWarning: Converting mask without torch.bool dtype to bool; this will negatively affect performance. Prefer to use a boolean mask directly. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343995026/work/aten/src/ATen/native/transformers/attention.cpp:150.)
  return torch._native_multi_head_attention(

It seems that the float_mask is treated as a bool mask by the built-in code.
I guess this is caused by scaled_dot_product_attention(), because the same code runs correctly on PyTorch 1.13.1.
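If the float mask really is being reinterpreted as boolean, the nan values follow directly: for a bool mask, True means "masked out", and a tensor from torch.randn is essentially never exactly zero, so every position gets masked and each softmax row becomes all -inf. A minimal sketch of that suspected conversion (my own reproduction of the effect, not PyTorch's actual internal code path):

```python
import torch

# Suspected conversion: reinterpret the float mask as boolean.
# Every nonzero entry becomes True, i.e. "masked out".
float_mask = torch.randn(10, 10)
as_bool = float_mask.to(torch.bool)  # randn() is ~never exactly 0.0

# Masked positions are filled with -inf before softmax; with every
# position masked, each row is all -inf and softmax yields nan.
scores = torch.zeros(10, 10).masked_fill(as_bool, float("-inf"))
print(torch.softmax(scores, dim=-1))  # all nan
```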

This looks like a bug in PyTorch 2.0.1.
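A possible workaround until this is fixed (assuming the non-fused Python path still handles float masks the way 1.13.1 did): keep the module off the fast path, e.g. by leaving it in training mode with dropout=0.0, since the fast-path check in torch/nn/modules/activation.py bails out when self.training is set:

```python
import torch
import torch.nn as nn

# embed_dim/num_heads are placeholders matching the repro above.
attn = nn.MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.0, batch_first=True)
attn.train()  # training mode skips _native_multi_head_attention

x = torch.randn((1, 10, 256))
float_mask = torch.randn((10, 10))

with torch.no_grad():
    # Slow path adds the float mask to the scores before softmax,
    # so the attention weights stay finite.
    res, scores = attn(query=x, key=x, value=x, attn_mask=float_mask)
print(scores)
```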

As discussed in: