Scaled_dot_product_attention, bf16, with my attn_mask

Context

Hi, I am trying to move our model from Triton's flash attention to the torch 2 flash attention (F.scaled_dot_product_attention), to benefit from torch.compile!
The problem lies in the attention mask. Our model uses attention biasing, which I need to pass through the attn_mask parameter. The model is also autoregressive, and since is_causal and attn_mask can't be combined, I folded the causal masking into attn_mask as well. Unfortunately, it seems that the flash attention backend currently does not support a float attention mask (the attn_mask parameter)?
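To make the constraint concrete, here is a small sketch (the shapes are placeholders matching the minimal example below): with only causal masking the flash kernel is usable through is_causal, but the additive bias has to go through attn_mask, and then the causal mask has to live there too.

import torch
import torch.nn.functional as F

# With purely causal masking the flash kernel is usable via is_causal:
q = torch.randn(1, 10, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 10, 1024, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 10, 1024, 64, device="cuda", dtype=torch.bfloat16)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# ...but an additive float bias has to be passed as attn_mask, and the causal
# structure then has to be folded into that same mask, because attn_mask and
# is_causal cannot be used together.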

Environment info

GPU: RTX 3090 (but I got the same error with our code on an A100)
torch==2.0.1

Minimal Example

import torch
from torch.backends.cuda import SDPBackend, sdp_kernel
import torch.nn.functional as F

backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}
# doesn't work
# shows UserWarning: Both fused kernels do not support non-null attn_mask
# raises RuntimeError: No available kernel.  Aborting execution.
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    query = torch.randn(1, 10, 1024, 64).cuda()
    key = torch.randn(1, 10, 1024, 64).cuda()
    value = torch.randn(1, 10, 1024, 64).cuda()

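    # random additive bias standing in for our attention biasing;
    # the causal (future) positions are then filled with -inf below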
    attn_mask = torch.randn(1, 10, 1024, 1024).cuda()
    causal_mask = ~torch.ones(query.shape[2], key.shape[2], dtype=torch.bool).tril(diagonal=0).to(attn_mask.device)
    attn_mask = attn_mask.masked_fill(causal_mask.view(1, 1, query.shape[2], key.shape[2]), -float("inf"))

    dropout_p = 0.0

    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        result = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p)
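
For reference, a fallback sketch (reusing the tensors from the example above) that does not force the flash backend, so SDPA can pick a kernel that accepts the float attn_mask. My assumption is that on 2.0.1 this ends up on the math backend (whether the memory-efficient kernel also takes a float mask depends on the PyTorch version), which of course defeats the point of the fused flash kernel:

    # allow fallback instead of forcing flash only
    with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
        result = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p)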

Is there any way to circumvent this limitation?


Hi, just checking: have you found a solution to this? I am facing a similar issue.

Thanks

Hi, unfortunately no. I don't think we can use torch's flash attention this way yet :frowning:

Hi,
Yes, I think the torch flash attention backend only accepts a null mask, so we will have to wait until PyTorch adds support for the mask.

Thanks

Hey @MFajcik1
Did you manage to overcome the issue somehow? Maybe there’s a third-party implementation of FlashAttention that supports arbitrary masking?