Explicitly forcing torch's MHA to use Flash Attention

How can we force torch to use the new SDPA implementation in torch.nn.MultiheadAttention?

I know that it is supposed to use it automatically, but I’d still prefer an explicit version.
There’s also a context manager, but I’m unsure where it’s supposed to be wrapped. Using it inside forward() doesn’t change anything, nor does simply calling:

torch.backends.cuda.enable_flash_sdp(enabled=True)

somewhere in the code.
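For reference, what I’ve been trying looks roughly like this. The sizes and shapes are made up, and I’m assuming the context manager in question is torch.backends.cuda.sdp_kernel, wrapped around the call into the module rather than placed inside forward():

import torch
import torch.nn as nn

device = "cuda"
embed_dim, num_heads = 512, 8  # made-up sizes for illustration

# flash attention only runs on fp16/bf16 inputs on CUDA
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True,
                            device=device, dtype=torch.float16)
x = torch.randn(4, 128, embed_dim, device=device, dtype=torch.float16)

# restrict SDPA to the flash backend for everything inside this block
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    # need_weights=False is required for MultiheadAttention to dispatch
    # to scaled_dot_product_attention at all; asking for the weights
    # falls back to the explicit implementation
    out, _ = mha(x, x, x, need_weights=False)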

Curious why you want to do this, since I’d imagine it could produce incorrect results?

Why would that be? The new SDPA doesn’t really trade off precision AFAIK.

I don’t know the specifics, but I’d imagine the speedups are enabled by making specific assumptions about the inputs, and if those assumptions are violated, correctness may not be guaranteed. If you have a use case that you think should be supported but currently isn’t, maybe you can file an issue?
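One way to check whether the flash kernel is even eligible for your inputs would be to disable the fallbacks and call scaled_dot_product_attention directly. Untested sketch, assuming a recent 2.x build and a CUDA device (the fp32 input here is deliberate, since flash attention needs fp16/bf16):

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim), fp32 on purpose
q = torch.randn(4, 8, 128, 64, device="cuda")

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    try:
        out = F.scaled_dot_product_attention(q, q, q)
    except RuntimeError as e:
        # with the other backends disabled you get an error (plus warnings
        # explaining which constraint ruled the flash kernel out) instead
        # of a silent fallback
        print("flash kernel not eligible:", e)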

Filing an issue doesn’t seem worth it, tbh. Maybe @ptrblck has some idea what I should do?