How can we force torch to use the new SDPA implementation in torch.nn.MultiheadAttention?
I know that it is supposed to use it automatically, but I’d still prefer an explicit version.
There’s also a context manager, but I’m unsure where it’s supposed to be wrapped. Using it in forward() doesn’t change anything, nor does simply stating:
I don’t know the specifics, but I’d imagine the speedups are enabled by making specific assumptions about the inputs, and if those assumptions are violated, correctness may not be guaranteed. If you have a use case that you think should be supported but currently isn’t, maybe you can file an issue?