I would like to know how to control which ATen operation is used during torch.export when exporting the scaled_dot_product_attention operation. Sometimes I get aten._scaled_dot_product_flash_attention_for_cpu, while other times I get aten.scaled_dot_product_attention, and I'm not sure what factors cause this difference.
Could someone please help point me to the related documentation? Thank you.
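For reference, here is a minimal sketch of the kind of export where I see this. The module, tensor shapes, and dtype are just placeholders for illustration, not my actual model:

```python
import torch
import torch.nn.functional as F

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        # The call whose lowering I'm asking about
        return F.scaled_dot_product_attention(q, k, v)

# Placeholder inputs: (batch, heads, seq_len, head_dim) on CPU
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

ep = torch.export.export(Attn(), (q, k, v))
# Inspect which ATen op the SDPA call was lowered to
print(ep.graph)
```

Depending on the environment, the printed graph shows one of the two ops I mentioned above for the same code.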