Hi @ptrblck,
I just wanted to confirm the best way to ensure that only the new FlashAttention backend in PyTorch 2.0 is being used for scaled dot product attention.
For example:
import torch
import torch.nn.functional as F

# PyTorch 2.0 flash attention: q, k, v, mask, dropout, causal, softmax_scale
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask,
        dropout_p=flash_attn_dropout,
        is_causal=causal,
        scale=scale
    )
I greatly appreciate your help.
Thank you,
Enrico
torch.backends.cuda.enable_flash_sdp
is not a context manager, so you could use torch.backends.cuda.sdp_kernel
instead, as seen here:
print(torch.backends.cuda.flash_sdp_enabled())
# True
print(torch.backends.cuda.mem_efficient_sdp_enabled())
# True
print(torch.backends.cuda.math_sdp_enabled())
# True

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    print(torch.backends.cuda.flash_sdp_enabled())
    # True
    print(torch.backends.cuda.mem_efficient_sdp_enabled())
    # False
    print(torch.backends.cuda.math_sdp_enabled())
    # False
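
For reference, a minimal end-to-end sketch (assuming PyTorch 2.0 and a CUDA device; the shapes and dtype below are illustrative): with the math and memory-efficient backends disabled, the call either dispatches to the flash kernel or raises a RuntimeError, so a successful call is itself confirmation that flash attention is being used. Note that in 2.0 the flash kernel typically expects half-precision inputs and no explicit attn_mask (is_causal is fine).

import torch
import torch.nn.functional as F

# Illustrative inputs: (batch, heads, seq_len, head_dim) in fp16 on the GPU.
q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    # With the other backends disabled, this either runs the flash kernel
    # or raises; it cannot silently fall back to math/mem-efficient.
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 1024, 64])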
Thank you for verifying that it should be used as follows:
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    ...
Much appreciated, as always.
Best,
Enrico