Bug: scaled_dot_product_attention slower than matmul

Hi,
I have written about this before and believe I have found the same (or a very similar) bug again: I replaced my manual implementation (essentially the reference version shown in the torch.nn.functional.scaled_dot_product_attention — PyTorch 2.8 documentation) with the supposedly efficient scaled_dot_product_attention and got much slower runtimes.
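For context, here is a simplified sketch of the two versions I am comparing (no mask, no dropout, random data; not my exact code):

```python
import torch
import torch.nn.functional as F

def manual_sdpa(q, k, v):
    # Plain matmul -> softmax -> matmul attention, essentially the
    # reference implementation from the SDPA documentation (no mask, no dropout).
    scale = q.size(-1) ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v

q = k = v = torch.randn(28800, 8, 48, device="cuda")
out_manual = manual_sdpa(q, k, v)
out_sdpa = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out_manual, out_sdpa, atol=1e-4))  # sanity check: same result
```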

This time I tried it with tensors of shape (batch_size, seq_length, dim): [28800, 8, 48], as well as [1, 28800, 8, 48] and [28800, 1, 8, 48] (the same data, with a singleton dimension added to hopefully match the layout SDPA expects). None of them was faster; all were slower by a factor of 5 on my A100. Note that the added singleton dimension is only there to simulate the heads dimension, since heads should be treated exactly like batches; still, the results were devastating.
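This is roughly how I am timing the two (a simplified sketch: warm-up plus explicit synchronization so the asynchronous CUDA kernels are actually measured):

```python
import time
import torch
import torch.nn.functional as F

def manual_sdpa(q, k, v):
    # Reference-style matmul attention (no mask, no dropout).
    scale = q.size(-1) ** -0.5
    return torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1) @ v

def bench(fn, q, k, v, iters=100):
    for _ in range(10):              # warm-up
        fn(q, k, v)
    torch.cuda.synchronize()         # include all async kernel launches in the timing
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(q, k, v)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e3  # ms per call

for shape in [(28800, 8, 48), (1, 28800, 8, 48), (28800, 1, 8, 48)]:
    q = k = v = torch.randn(*shape, device="cuda")
    print(shape,
          f"manual: {bench(manual_sdpa, q, k, v):.3f} ms,",
          f"sdpa: {bench(F.scaled_dot_product_attention, q, k, v):.3f} ms")
```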

Then I tried it with another shape that occurs often in my project:

[115200, 8, 24] (with this 3-D shape, SDPA was about 5 times slower),

[115200, 1, 8, 24], and

[1, 115200, 8, 24]. However, the last two attempts gave me this:

```
torch.AcceleratorError: CUDA error: invalid configuration argument
Search for `cudaErrorInvalidConfiguration' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
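For reference, a minimal sketch of the failing call (random data; both 4-D variants of the second shape trigger it for me):

```python
import torch
import torch.nn.functional as F

# The 3-D shape [115200, 8, 24] runs (just ~5x slower than the manual version),
# but both 4-D variants raise the invalid-configuration error on my A100.
for shape in [(115200, 1, 8, 24), (1, 115200, 8, 24)]:
    q = k = v = torch.randn(*shape, device="cuda")
    out = F.scaled_dot_product_attention(q, k, v)
    torch.cuda.synchronize()  # the error may only surface here, since kernels launch asynchronously
```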