Compiling a model that uses sdp_kernel to enable the backends does not work.

The model uses "with sdp_kernel(enable_math=True, enable_flash=True, enable_mem_efficient=True):"

When calling torch.compile on this model, it generates the following error:

torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor _GeneratorContextManager call_function <function sdp_kernel at 0x7e3c02b79fc0>

How can I use sdp_kernel with torch.compile?

Thanks.

torch.compile works for me but raises the deprecation warning:

import torch

query = torch.rand(64, 12, 77, 64, dtype=torch.float16, device="cuda")
key = torch.rand(64, 12, 77, 64, dtype=torch.float16, device="cuda")
value = torch.rand(64, 12, 77, 64, dtype=torch.float16, device="cuda")
attn_mask = torch.rand(64, 1, 77, 77, dtype=torch.float16, device="cuda")


def fun(query, key, value, attn_mask):
    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=True, enable_mem_efficient=True):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
    return out

out = fun(query, key, value, attn_mask)
# FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.

fun_compiled = torch.compile(fun)
out = fun_compiled(query, key, value, attn_mask)
# FutureWarning: `torch.backends.cuda.sdp_kernel()` is deprecated. In the future, this context manager will be removed. Please see `torch.nn.attention.sdpa_kernel()` for the new context manager, with updated signature.

Can you move the context manager so it is outside of the compiled region? E.g.:

with sdp_kernel(…):
    out = compiled_model(inp)

Although cc @drisspg (we should make sure the new, non-deprecated context manager works inside of a compiled region)
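The suggestion above can be sketched as a runnable snippet: compile only the attention function and let the (deprecated) context manager wrap the call, so Dynamo never has to trace it. backend="eager" and the small shapes are illustrative choices to keep the sketch light, not part of the original posts:

```python
import torch
from torch.backends.cuda import sdp_kernel  # deprecated, but matches the thread

def attn(query, key, value):
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)

# backend="eager" avoids invoking the full inductor stack for this sketch;
# drop it in real code.
compiled_attn = torch.compile(attn, fullgraph=True, backend="eager")

query = key = value = torch.rand(2, 4, 8, 16)

# The context manager now wraps the call instead of living inside the
# compiled function, so Dynamo does not see it at all.
with sdp_kernel(enable_math=True, enable_flash=True, enable_mem_efficient=True):
    out = compiled_attn(query, key, value)
```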


Sorry, I forgot to mention that it was with fullgraph=True:

fun_compiled = torch.compile(fun, fullgraph=True)

It yields:

File "/home/…/env/lib/python3.12/site-packages/torch/_dynamo/exc.py", line 297, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor _GeneratorContextManager call_function <function sdp_kernel at 0x7fa07c0c8400>

from user code:
File "/home/…/test.py", line 10, in fun
    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=True, enable_mem_efficient=True):

I guess I can do that as a workaround; I'll try.