How to Control ATen Ops for `sdpa` in PyTorch Export

I would like to know how to control which ATen operator `torch.export` produces when exporting `scaled_dot_product_attention`. Sometimes I get `aten._scaled_dot_product_flash_attention_for_cpu`, while other times I get `aten.scaled_dot_product_attention`, and I'm not sure which factors cause this difference.
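
For reference, here is a minimal sketch of the kind of export I am running (the module and tensor shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

# placeholder shapes: (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 8, 128, 64)

ep = torch.export.export(Attn(), (q, k, v))
print(ep.graph)  # shows which aten op the sdpa call was traced to
```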

Could someone point me to the relevant documentation? Thank you.

You could try to select the desired backend via `torch.backends.cuda.sdp_kernel`, but I'm unsure if this would work with export and for your use case.
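
Untested, but something along these lines, wrapping the export call in the context manager. Note that `torch.backends.cuda.sdp_kernel` is deprecated in recent releases in favor of `torch.nn.attention.sdpa_kernel`, which is what the sketch below uses:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 8, 128, 64)

# Constrain sdpa to the math backend while tracing. Whether torch.export
# honors this context manager is exactly the open question here, so
# treat this as an experiment rather than a guaranteed solution.
with sdpa_kernel(SDPBackend.MATH):
    ep = torch.export.export(Attn(), (q, k, v))

print(ep.graph)  # check which aten op sdpa lowered to
```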
