Decomposition error with `_scaled_dot_product_flash_attention`

When I try to decompose `aten._scaled_dot_product_flash_attention.default` in my own backend using `get_decompositions([aten._scaled_dot_product_flash_attention.default])`, I get the following error:

File "torch/_functorch/_aot_autograd/", line 335, in assert_functional_graph
   assert (
torch._dynamo.exc.BackendCompilerFailed: backend='mybackend' raised:
AssertionError: aot_autograd expected to have an entirely functional graph, but found %masked_fill_ : [num_users=1] = call_function[target=torch.ops.aten.masked_fill_.Scalar](args = (%zeros_like, %logical_not, -inf), kwargs = {})

How can I decompose `aten._scaled_dot_product_flash_attention.default`? This is with torch 2.2.0.
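For reference, the backend setup described above might look roughly like the sketch below. This is an assumption about how "my own backend" is wired, not the poster's actual code: `my_compiler` and `mybackend` are hypothetical names, and the compiler just prints the traced aten graph and runs it eagerly. It uses `aot_autograd` from `torch._dynamo.backends.common`, which applies the decomposition table while tracing.

```python
import torch
from functorch.compile import make_boxed_func
from torch._decomp import get_decompositions
from torch._dynamo.backends.common import aot_autograd

aten = torch.ops.aten

# Decomposition table for the op from the error. get_decompositions only
# returns entries for ops that have a registered Python decomposition.
decomps = get_decompositions([aten._scaled_dot_product_flash_attention.default])


def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical stand-in for a real backend: show the aten-level graph
    # that aot_autograd produced, then run it eagerly.
    print(gm.graph)
    # aot_autograd calls compiled functions with a single list of args,
    # so wrap the eager forward in the "boxed" calling convention.
    return make_boxed_func(gm.forward)


# aot_autograd traces the forward (and backward) graphs and applies
# `decomps` during tracing before handing the graph to my_compiler.
mybackend = aot_autograd(fw_compiler=my_compiler, decompositions=decomps)
```

Per the post, compiling a function that reaches the flash-attention kernel with `torch.compile(fn, backend=mybackend)` is where the `assert_functional_graph` failure appears on torch 2.2.0. The `masked_fill_` node in the message is an in-place op, which is exactly what that assertion rejects.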