Bug? scaled_dot_product_attention slower than manual multiplication?

Hi,

I have observed that using scaled_dot_product_attention in my script slows my code down a bit.

Then I tested whether scaled_dot_product_attention is actually more efficient in my case, and it turns out it is not: the manual implementation is faster. Is this a bug, or is it expected? Yesterday I wrote a similar post for RMSNorm ( Bug? RMSNorm slower than LayerNorm - #3 by JosAr ), and the newest torch version fixed that, but it does not fix the behaviour I describe below.

I tested a very simple manual dot-product attention against both the unoptimised reference code given on the torch.nn.functional.scaled_dot_product_attention page of the PyTorch 2.8 documentation and the optimised built-in version, using the following script:

```python
import torch
import torch.nn.functional as F
import time
import math

device = 'cuda'
batch_size, seq_len, d_model = 32, 512, 768  # legacy line, not used below
# Leading dims act as batch dims; the last two are L = 60 and E = 32.
# shape = [12, 6, 80, 60, 32]  # second configuration from the results table
shape = [6, 12, 40, 60, 32]
d_k = shape[-1]  # per-head feature dim E; SDPA scales by 1/sqrt(E) by default

# Generate random data
q = torch.randn(shape, device=device)
k = torch.randn(shape, device=device)
v = torch.randn(shape, device=device)

def manual_attention(q, k, v):
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v)

# Inefficient implementation equivalent to SDPA (reference code from the docs):
def manual_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        # Build the mask on the same device as attn_bias so masked_fill_ works on CUDA.
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

def benchmark(func, *args, num_runs=500, warmup=10):
    # Warmup
    for _ in range(warmup):
        _ = func(*args)

    # Wait for all queued GPU work before starting the clock
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_runs):
        _ = func(*args)

    # Wait for the timed kernels to finish before stopping the clock
    torch.cuda.synchronize()

    return (time.time() - start) / num_runs

# Benchmark all three approaches
manual_time = benchmark(manual_attention, q, k, v)
builtin_time = benchmark(F.scaled_dot_product_attention, q, k, v)
inefficient_time = benchmark(manual_scaled_dot_product_attention, q, k, v)


print(torch.__version__)
print(f"Manual implementation: {manual_time*1000:.2f} ms")
print(f"Built-in SDPA: {builtin_time*1000:.2f} ms")
print(f"inefficient time: {inefficient_time*1000:.2f} ms")
```
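The `benchmark` helper above times with `time.time()` around `torch.cuda.synchronize()`, which is valid. As a cross-check, here is a sketch of the same measurement with PyTorch's `torch.utils.benchmark.Timer`, which according to its docs handles warmup and CUDA synchronization itself; the CPU fallback is my addition so the snippet runs without a GPU.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

# Assumption: fall back to CPU so the sketch also runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
shape = [6, 12, 40, 60, 32]  # first configuration from the post
q, k, v = (torch.randn(shape, device=device) for _ in range(3))

# Timer performs warmups and synchronizes CUDA around the timed region.
# On CPU it runs single-threaded by default (num_threads=1).
timer = benchmark.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v)",
    globals={"F": F, "q": q, "k": k, "v": v},
    label="SDPA",
    sub_label=str(shape),
)
measurement = timer.blocked_autorange(min_run_time=0.2)
print(measurement)  # summary including the median time per call
print(f"mean: {measurement.mean * 1e3:.3f} ms")
```

Changing `stmt` to `manual_attention(q, k, v)` (and adding `manual_attention` to `globals`) would time the manual version in exactly the same way.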

I got the following results on an A100 with torch 2.9.0.dev20250813+cu129:

| shape | Manual | Inefficient | SDPA |
|---|---|---|---|
| [6, 12, 40, 60, 32] | 0.51 ms | 0.53 ms | 0.65 ms |
| [12, 6, 80, 60, 32] | 0.93 ms | 1.03 ms | 1.26 ms |

Shouldn’t the efficient SDPA be faster, at least compared to the “inefficient” implementation?

Furthermore, I’d like to make a feature request (if that’s possible): could the already implemented attn_bias become an optional argument of SDPA? Papers like the Swin Transformer (https://arxiv.org/pdf/2103.14030) add a positional bias to the attention scores (Eq. (4)). It would be great to have an efficient implementation of that as well!
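For what it’s worth, the PyTorch docs state that SDPA already accepts a floating-point `attn_mask` (same dtype as the query) that is added to the attention scores. A minimal sketch, assuming a Swin-style bias B with one (L, L) matrix per head (random here, purely hypothetical), checks this against Eq. (4), SoftMax(QK^T/√d + B)V, computed by hand. Which kernel SDPA picks when a mask is passed depends on backend and version, so this does not settle the speed side of the request.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, L, E = 2, 4, 49, 32  # hypothetical sizes; 49 = 7x7 tokens of one Swin window
q, k, v = (torch.randn(N, H, L, E) for _ in range(3))

# Hypothetical positional bias B: one (L, L) matrix per head.
# Shape (1, H, L, L) broadcasts against the (N, H, L, L) attention scores.
bias = torch.randn(1, H, L, L)

# A float attn_mask is added to the scaled scores before the softmax.
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# Reference: SoftMax(Q K^T / sqrt(E) + B) V written out by hand.
scores = q @ k.transpose(-2, -1) / E ** 0.5 + bias
out_ref = torch.softmax(scores, dim=-1) @ v

# Loose tolerance: backends may accumulate in a slightly different order.
print(torch.allclose(out_sdpa, out_ref, atol=1e-4))
```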

Thanks a lot already! :slight_smile:

Edit: If it’s due to the shapes of the query, key and value tensors, there are also efficient implementations for that (I’ve seen GitHub - zzd1992/FlashWindowAttention: Speedup the attention computation of Swin Transformer); will those be integrated in the future? Still, I thought SDPA should at least be on par with the “inefficient” implementation.

No, it is not a bug.

SDPA is much faster (more efficient) with causal attention.

For a correct benchmark, the shape variable in your script must be changed to the expected (batch_size, num_heads, seq_length, head_dim) format:

`shape = [60, 8, 512, 32]  # batch_size, num_heads, seq_length, head_dim`

Now run it. The table below shows the timings (in ms) as the sequence length increases from 512 to 1024 to 4096:

| Implementation | L = 512 | L = 1024 | L = 4096 |
|---|---|---|---|
| Manual (no causal masking) | 13.67 | 53.86 | 140.16 |
| Inefficient (causal masking) | 18.03 | 70.76 | 214.05 |
| SDPA (causal masking) | 3.95 | 15.4 | 40.38 |

Your implementation did not include a causal mask. Adding one to the scores costs extra FLOPs (and latency) that grow quadratically with sequence length, because the mask covers all L × L scores; that is why the inefficient implementation is slower than the manual one. SDPA handles causal masking far more efficiently: compared to the inefficient implementation it is roughly 4.5 to 5.3 times faster in the table above!
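To reproduce the causal case, the built-in path is selected with `is_causal=True` rather than by passing a mask tensor. The sketch below (small CPU-sized shapes chosen only for illustration) checks that it matches a manual version that builds and applies the full L × L boolean mask explicitly. As I understand it, the fused kernels can skip fully masked blocks instead of materializing the mask, which is where much of the gap comes from.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, L, E = 2, 8, 128, 32  # small hypothetical sizes so this runs quickly on CPU
q, k, v = (torch.randn(N, H, L, E) for _ in range(3))

# Built-in causal attention: the caller never materializes a mask.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Explicit equivalent: True marks positions a query may attend to.
keep = torch.ones(L, L, dtype=torch.bool).tril()
scores = q @ k.transpose(-2, -1) / E ** 0.5
scores = scores.masked_fill(~keep, float("-inf"))
out_manual = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(out_causal, out_manual, atol=1e-4))
```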

These results come from running your script on my device (an L40 GPU); the only change I made was modifying the shape variable as mentioned above.

Hope it helps!

Regards

~Arun

Hi!

Thanks for the answer! I have progressed in finding the issue, but it still exists.

Indeed, the shapes I used earlier are the ones I want to use. The example code unfortunately contained one line of legacy code that might have confused you (sorry for that), but I am interested in the runtimes for shape = [12, 6, 80, 60, 32] and shape = [6, 12, 40, 60, 32]. The last dimension is the feature dimension (E = 32), and the second-to-last is the sequence length (L = 60).

All other dimensions (6 heads, 12 batches, 80 “sliding windows”; I am working on an image project) should, in my understanding of attention, be treated as batch dimensions; computationally it should not matter whether they are heads or batches, unless I am mistaken.
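The claim that the extra leading dimensions act purely as batch dimensions can be checked numerically. A minimal sketch on CPU with random data (illustrative only): call SDPA on the 5-D tensors, then fold the first two dimensions into one batch dimension to get a 4-D layout, call it again, and reshape back.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
shape = [6, 12, 40, 60, 32]  # from the post: three leading dims, then L = 60, E = 32
q, k, v = (torch.randn(shape) for _ in range(3))

# SDPA on the original 5-D tensors.
out_5d = F.scaled_dot_product_attention(q, k, v)

# Fold the first two dims into one batch dim: (6 * 12, 40, 60, 32).
q4, k4, v4 = (t.reshape(-1, *shape[-3:]) for t in (q, k, v))
out_4d = F.scaled_dot_product_attention(q4, k4, v4).reshape(shape)

# Loose tolerance: different backends may round slightly differently.
print(torch.allclose(out_5d, out_4d, atol=1e-4))
```

If the outputs match, the 4-D call can stand in for the 5-D one at the cost of a reshape (a view for contiguous tensors), which makes it easy to time both layouts side by side.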

I broke it down further to be closer to your example and to the documentation:

If I read the torch.nn.functional.scaled_dot_product_attention page of the PyTorch 2.8 documentation correctly, running shape = [24, 6, 2048, 32] (N, H, L, E) or shape = [24, 1, 6, 2048, 32] (N, …, H, L, E) should not make any difference, right? There is only one additional dimension of size 1.

For me it does, though, but only for the built-in SDPA, not for the other implementations.

What am I missing?