Bug? scaled_dot_product_attention slower than manual multiplication?

Hi,

I have observed that using scaled_dot_product_attention in my script slows the code down a bit.

I then tested whether scaled_dot_product_attention is actually more efficient in my case; it turns out it is not: the manual implementation is faster. Is this a bug, or should it be expected? I wrote about something similar yesterday ( Bug? RMSNorm slower than LayerNorm - #3 by JosAr ) for RMSNorm and the newest torch version fixed that, but it does not fix the behaviour I describe now:

I compared a very simple manual dot-product attention, the unoptimised reference implementation given in torch.nn.functional.scaled_dot_product_attention — PyTorch 2.8 documentation, and the optimised built-in SDPA with the following script:

import torch
import torch.nn.functional as F
import time
import math

device = 'cuda'
# Two shapes were tested (the second assignment is the one that runs; swap the lines to test the other):
shape = [12, 6, 80, 60, 32]
shape = [6, 12, 40, 60, 32]
d_k = shape[-1]  # head dimension, so the manual scaling matches SDPA's default 1/sqrt(E)

# Generate random data
q = torch.randn(shape, device=device)
k = torch.randn(shape, device=device)
v = torch.randn(shape, device=device)

def manual_attention(q, k, v):
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v)

# Inefficient implementation equivalent to SDPA:
def manual_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

def benchmark(func, *args, num_runs=500, warmup=10):
    # Warmup
    for _ in range(warmup):
        _ = func(*args)

    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_runs):
        _ = func(*args)

    torch.cuda.synchronize()

    return (time.time() - start) / num_runs

# Benchmark both approaches
manual_time = benchmark(manual_attention, q, k, v)
builtin_time = benchmark(F.scaled_dot_product_attention, q, k, v)
inefficient_time = benchmark(manual_scaled_dot_product_attention, q, k, v)


print(torch.__version__)
print(f"Manual implementation: {manual_time*1000:.2f} ms")
print(f"Built-in SDPA: {builtin_time*1000:.2f} ms")
print(f"inefficient time: {inefficient_time*1000:.2f} ms")

I got the following results on an A100, using torch 2.9.0.dev20250813+cu129:

shape                  Manual     Inefficient    SDPA
[6, 12, 40, 60, 32]    0.51 ms    0.53 ms        0.65 ms
[12, 6, 80, 60, 32]    0.93 ms    1.03 ms        1.26 ms
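
(In case my simple timing loop is skewing the numbers: a cross-check with torch.utils.benchmark.Timer, which handles CUDA synchronisation itself, would look roughly like this - just a sketch for one of the three functions:)

import torch.utils.benchmark as benchmark_utils

# Cross-check one timing with torch.utils.benchmark.Timer, which takes care
# of CUDA synchronisation internally.
timer = benchmark_utils.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v)",
    globals={"F": F, "q": q, "k": k, "v": v},
)
print(timer.timeit(500))  # prints the measured time per run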

Shouldn’t the efficient SDPA be faster, at least compared to the “inefficient” reference implementation?
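
(One thing I can’t tell from the numbers alone is which backend SDPA actually picks for these inputs. As far as I understand, torch.nn.attention.sdpa_kernel can restrict the choice, so a diagnostic sketch like the following should show which kernels are even eligible here:)

from torch.nn.attention import sdpa_kernel, SDPBackend

# Try each backend on its own; if an enabled backend cannot handle these
# (5-D, fp32) inputs, SDPA raises a RuntimeError and warns about the reason.
for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH):
    try:
        with sdpa_kernel(backend):
            F.scaled_dot_product_attention(q, k, v)
        print(f"{backend}: ran")
    except RuntimeError as err:
        print(f"{backend}: not usable ({err})")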

Furthermore, I’d like to make a feature request (if that’s possible?): could the internally used attn_bias be exposed as an optional argument of SDPA? Papers like the Swin Transformer (https://arxiv.org/pdf/2103.14030) add a relative position bias to the attention scores (their equation (4)). It would be great to have an efficient implementation of that as well!
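
(Unless I’m misreading the docs, a float attn_mask is already added to the attention scores, so a Swin-style relative position bias can in principle be passed that way - a minimal sketch with illustrative shapes, B=batch, H=heads, L=query length, S=key length, E=head dim:)

# Minimal sketch: pass a relative position bias as a float attn_mask,
# which SDPA adds to the attention scores before the softmax.
B, H, L, S, E = 2, 4, 64, 64, 32
q_b = torch.randn(B, H, L, E, device=device)
k_b = torch.randn(B, H, S, E, device=device)
v_b = torch.randn(B, H, S, E, device=device)
rel_pos_bias = torch.randn(H, L, S, device=device)  # e.g. looked up from a learned bias table as in Swin
out = F.scaled_dot_product_attention(q_b, k_b, v_b, attn_mask=rel_pos_bias)  # mask broadcasts over the batch dim

What I don’t know is whether an arbitrary float mask still allows the fused kernels or forces a fallback, which is why a dedicated, efficient bias path would be nice.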

Thanks a lot already! :slight_smile:

Edit: If it’s due to the shapes of the query, key and value tensors, there are also efficient implementations for that case (I’ve seen GitHub - zzd1992/FlashWindowAttention: Speedup the attention computation of Swin Transformer) - will those be integrated in the future? Still, I thought SDPA should at least be on par with the “inefficient” reference implementation.
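
(To test the shape hypothesis, I guess one could collapse the leading batch dimensions into the usual 4-D (batch, heads, seq_len, head_dim) layout before calling SDPA - a rough sketch, reusing the benchmark helper from the script above:)

# Rough sketch: flatten the leading dims to 4-D, which is, as far as I
# understand, the layout the fused SDPA backends expect; 5-D inputs may
# fall back to the math kernel.
q4 = q.reshape(-1, *q.shape[-3:])
k4 = k.reshape(-1, *k.shape[-3:])
v4 = v.reshape(-1, *v.shape[-3:])
builtin_time_4d = benchmark(F.scaled_dot_product_attention, q4, k4, v4)
print(f"Built-in SDPA (4-D inputs): {builtin_time_4d*1000:.2f} ms")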