Context
Hi, I am trying to move our model from Triton's flash attention to the PyTorch 2 flash attention, to benefit from torch.compile. However, the problem lies in the attention mask. Our model uses attention biasing, which I need to integrate into the attn_mask parameter. Our model is also autoregressive, and since is_causal and attn_mask cannot be combined, I folded the causal masking into attn_mask as well. Unfortunately, it seems that flash attention currently does not support float attention masks (the attn_mask parameter)?
Environment info
GPU: RTX 3090 (but I got the same error with our code on A100)
torch==2.0.1
Minimal Example
import torch
from torch.backends.cuda import SDPBackend, sdp_kernel
import torch.nn.functional as F

backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {"enable_math": False, "enable_flash": False, "enable_mem_efficient": True},
}

# doesn't work:
# shows UserWarning: Both fused kernels do not support non-null attn_mask
# raises RuntimeError: No available kernel. Aborting execution.
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    query = torch.randn(1, 10, 1024, 64).cuda()
    key = torch.randn(1, 10, 1024, 64).cuda()
    value = torch.randn(1, 10, 1024, 64).cuda()

    # Float attention bias with causal masking folded into the same attn_mask.
    attn_mask = torch.randn(1, 10, 1024, 1024).cuda()
    causal_mask = ~torch.ones(query.shape[2], key.shape[2], dtype=torch.bool).tril(diagonal=0).to(attn_mask.device)
    attn_mask = attn_mask.masked_fill(causal_mask.view(1, 1, query.shape[2], key.shape[2]), -float("inf"))

    dropout_p = 0.0
    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        result = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p)
Is there any way to circumvent this limitation?
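For reference, one fallback I have found that does accept an arbitrary float attn_mask is the math backend (enable_math=True), at the cost of losing the fused kernel. Below is a minimal sketch of the same mask construction running through that path; the sizes are deliberately tiny and the tensors are on CPU (where SDPA uses the math implementation anyway) purely for illustration, and on CUDA one would wrap the call in sdp_kernel with only enable_math set.

```python
import torch
import torch.nn.functional as F

# Small, CPU-friendly sizes purely for illustration.
B, H, L, E = 1, 2, 8, 16
query = torch.randn(B, H, L, E)
key = torch.randn(B, H, L, E)
value = torch.randn(B, H, L, E)

# Float attention bias with causal masking folded in, as in the
# minimal example above: future positions get -inf, the rest keep the bias.
bias = torch.randn(B, H, L, L)
causal = torch.ones(L, L, dtype=torch.bool).tril(diagonal=0)
attn_mask = bias.masked_fill(~causal.view(1, 1, L, L), float("-inf"))

# The math path accepts a non-null float attn_mask.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 2, 8, 16])
```

This preserves the semantics (bias plus causal mask) but not the speed, so it is only a stopgap rather than an answer to the flash-attention limitation itself.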