Flash Attention

Hi @ptrblck,

I just wanted to confirm the best way to ensure that only the new Flash Attention kernel in PyTorch 2.0 is used for scaled dot product attention:

For example:

import torch
import torch.nn.functional as F

# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask,
        dropout_p=flash_attn_dropout,
        is_causal=causal,
        scale=scale
    )

I greatly appreciate your help.

Thank you,

Enrico

torch.backends.cuda.enable_flash_sdp is not a context manager; you could use the torch.backends.cuda.sdp_kernel context manager instead, as seen here:

print(torch.backends.cuda.flash_sdp_enabled())
# True
print(torch.backends.cuda.mem_efficient_sdp_enabled())
# True
print(torch.backends.cuda.math_sdp_enabled())
# True

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    print(torch.backends.cuda.flash_sdp_enabled())
    # True
    print(torch.backends.cuda.mem_efficient_sdp_enabled())
    # False
    print(torch.backends.cuda.math_sdp_enabled())
    # False
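
As a quick sanity check (a minimal sketch, assuming fp16 CUDA tensors of shape [batch, heads, seq_len, head_dim]): with only the flash backend enabled, scaled_dot_product_attention raises a RuntimeError if flash attention cannot handle the inputs rather than silently falling back to another backend, so a successful call inside this context manager indicates the flash kernel was used.

import torch
import torch.nn.functional as F

# flash attention expects half-precision (or bf16) CUDA tensors
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    # only the flash kernel is allowed here; unsupported inputs
    # (e.g. float32 tensors or an explicit attn_mask) would raise
    # a RuntimeError instead of falling back to math/mem-efficient
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 1024, 64])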

Thank you for verifying it should be used as such:

with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    ...

Appreciation as always.

Best,

Enrico
