As far as I know, all the current SDPA backends are either written in C++ (e.g., flash_attention and mem_efficient_attention) or provided by C++ libraries (e.g., cuDNN). Because the SDPA backend selection code is written in C++, it can only call backends through C++ APIs. If I have an SDPA backend written entirely in Python, how can I add it to PyTorch?
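For concreteness, the kind of pure-Python backend in question might look like the minimal sketch below. It computes `softmax(Q K^T * scale) V` with optional causal masking and no dropout, and it assumes inputs shaped `(batch, heads, seq_len, head_dim)`. The name `python_sdpa` is hypothetical. The sketch only checks its output against `torch.nn.functional.scaled_dot_product_attention`; it does not register itself with PyTorch's C++ backend selection, which is exactly the open question.

```python
import math

import torch
import torch.nn.functional as F


def python_sdpa(query, key, value, is_causal=False, scale=None):
    """Hypothetical pure-Python SDPA: softmax(Q K^T * scale) V, no dropout."""
    L, S = query.size(-2), key.size(-2)
    # Default scaling matches PyTorch: 1 / sqrt(head_dim).
    scale = 1.0 / math.sqrt(query.size(-1)) if scale is None else scale
    scores = query @ key.transpose(-2, -1) * scale
    if is_causal:
        # Top-left aligned lower-triangular mask; True means "may attend".
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ value


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
out = python_sdpa(q, k, v, is_causal=True)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out, ref, atol=1e-5))
```

Calling a function like this directly works, but `F.scaled_dot_product_attention` will never dispatch to it, because its backend list is fixed in C++.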