torch.compile works for me, but it raises the deprecation warning:
# Repro fixtures: fp16 CUDA tensors of shape (64, 12, 77, 64) — presumably
# (batch, heads, seq_len, head_dim), the layout F.scaled_dot_product_attention
# expects. attn_mask is (64, 1, 77, 77), broadcasting over the head dimension.
query = torch.rand(64, 12, 77, 64, dtype=torch.float16, device="cuda")
key = torch.rand(64, 12, 77, 64, dtype=torch.float16, device="cuda")
value = torch.rand(64, 12, 77, 64, dtype=torch.float16, device="cuda")
attn_mask = torch.rand(64, 1, 77, 77, dtype=torch.float16, device="cuda")
def fun(query, key, value, attn_mask):
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=True, enable_mem_efficient=True):
out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
return out
# Eager call: runs fine, but emits the FutureWarning quoted below.
out = fun(query, key, value, attn_mask)
# FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.
# Default torch.compile (fullgraph=False): compiles and runs, but the same
# FutureWarning is still raised (the context manager falls back to eager).
fun_compiled = torch.compile(fun)
out = fun_compiled(query, key, value, attn_mask)
# FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.
Sorry — I forgot to mention that it was with fullgraph=True:
# With fullgraph=True, Dynamo cannot graph-break around the deprecated
# context manager and raises the Unsupported error shown below.
fun_compiled = torch.compile(fun, fullgraph=True)
it yields:
File "/home/…/env/lib/python3.12/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor _GeneratorContextManager call_function <function sdp_kernel at 0x7fa07c0c8400>
from user code:
File "/home/…/test.py", line 10, in fun
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=True, enable_mem_efficient=True):