When I try to decompose aten._scaled_dot_product_flash_attention.default with my own backend using get_decompositions([aten._scaled_dot_product_flash_attention.default]), I get the following error:
File "torch/_functorch/_aot_autograd/functional_utils.py", line 335, in assert_functional_graph
assert (
torch._dynamo.exc.BackendCompilerFailed: backend='mybackend' raised:
AssertionError: aot_autograd expected to have an entirely functional graph, but found %masked_fill_ : [num_users=1] = call_function[target=torch.ops.aten.masked_fill_.Scalar](args = (%zeros_like, %logical_not, -inf), kwargs = {})
How can I decompose aten._scaled_dot_product_flash_attention.default? This is with torch 2.2.0.
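For reference, this is roughly how I'm wiring the decomposition table into my backend, using the aot_autograd wrapper from torch._dynamo.backends.common (inner_compiler here is a placeholder for my real compiler):

```python
import torch
from torch._decomp import get_decompositions
from torch._dynamo.backends.common import aot_autograd

aten = torch.ops.aten

# Fetch the registered decomposition for the flash-attention op
decomps = get_decompositions([aten._scaled_dot_product_flash_attention.default])

def inner_compiler(gm, example_inputs):
    # Placeholder for my backend: just run the decomposed graph as-is
    return gm.forward

# Wrap the compiler with aot_autograd so the decomposition table is applied
my_backend = aot_autograd(fw_compiler=inner_compiler, decompositions=decomps)
```

Compiling a function that calls torch.nn.functional.scaled_dot_product_attention with backend=my_backend is what raises the assertion above.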