Hi,
I have noticed that using scaled_dot_product_attention in my script slows it down a bit.
So I tested whether scaled_dot_product_attention is actually more efficient in my case, and it turns out it is not: the manual implementation is faster. Is this a bug, or should it be expected? I posted something similar yesterday ( Bug? RMSNorm slower than LayerNorm - #3 by JosAr ) about RMSNorm, and the newest torch version fixed that one, but it does not fix the behaviour I describe here:
I compared a very simple manual dot-product attention, the unoptimised reference code given in torch.nn.functional.scaled_dot_product_attention — PyTorch 2.8 documentation, and the optimised built-in, using the following script:
import torch
import torch.nn.functional as F
import time
import math
device = 'cuda'
batch_size, seq_len, d_model = 32, 512, 768
# Two test shapes; the second assignment overrides the first,
# swap/comment them to reproduce the other row of the results table below.
shape = [12, 6, 80, 60, 32]
shape = [6, 12, 40, 60, 32]
num_heads = 8
d_k = shape[-1]  # head dimension, used for the softmax scaling in manual_attention
# Generate random data
q = torch.randn(shape, device=device)
k = torch.randn(shape, device=device)
v = torch.randn(shape, device=device)
def manual_attention(q, k, v):
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v)
# Inefficient implementation equivalent to SDPA:
def manual_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                                        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias
    if enable_gqa:
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
def benchmark(func, *args, num_runs=500, warmup=10):
    # Warmup
    for _ in range(warmup):
        _ = func(*args)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_runs):
        _ = func(*args)
    torch.cuda.synchronize()
    return (time.time() - start) / num_runs
# Benchmark both approaches
manual_time = benchmark(manual_attention, q, k, v)
builtin_time = benchmark(F.scaled_dot_product_attention, q, k, v)
inefficient_time = benchmark(manual_scaled_dot_product_attention, q, k, v)
print(torch.__version__)
print(f"Manual implementation: {manual_time*1000:.2f} ms")
print(f"Built-in SDPA: {builtin_time*1000:.2f} ms")
print(f"inefficient time: {inefficient_time*1000:.2f} ms")
These are the results I got on an A100, using torch 2.9.0.dev20250813+cu129:
| shape | Manual | Inefficient | SDPA |
|---|---|---|---|
| [6, 12, 40, 60, 32] | 0.51 ms | 0.53 ms | 0.65 ms |
| [12, 6, 80, 60, 32] | 0.93 ms | 1.03 ms | 1.26 ms |
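In case my plain time.time() loop is skewing things: here is a sketch of how I would redo the measurement with torch.utils.benchmark, which handles warmup and CUDA synchronization itself. The table above was not produced with it.

```python
# Sketch only: re-measure with torch.utils.benchmark instead of time.time().
# Assumes the script above has already defined manual_attention, F, q, k, v.
import torch.utils.benchmark as tbench

for name, stmt in [("manual", "manual_attention(q, k, v)"),
                   ("sdpa", "F.scaled_dot_product_attention(q, k, v)")]:
    timer = tbench.Timer(
        stmt=stmt,
        globals={"manual_attention": manual_attention, "F": F,
                 "q": q, "k": k, "v": v},
    )
    m = timer.timeit(500)
    print(f"{name}: {m.median * 1000:.2f} ms")
```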
Shouldn’t the efficient SDPA be faster, at least compared to the “inefficient” reference implementation?
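For what it’s worth, I believe the backend SDPA dispatches to can be pinned, which might help narrow down where the time goes. A sketch (I have not rerun the table with this, and the fp32 inputs plus the extra leading batch dimensions may rule some backends out):

```python
# Sketch only: force a specific SDPA backend to see which one is picked/fastest here.
# Reuses the benchmark(), q, k, v defined in the script above.
from torch.nn.attention import SDPBackend, sdpa_kernel

for backend in [SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]:
    try:
        with sdpa_kernel(backend):
            t = benchmark(F.scaled_dot_product_attention, q, k, v)
        print(backend, f"{t * 1000:.2f} ms")
    except RuntimeError as e:
        # Raised if the restricted backend cannot handle this input (dtype, shape, ...)
        print(backend, "not usable for this input:", e)
```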
Furthermore, I’d like to make a feature request (if that’s possible): could the attn_bias that already appears in the reference implementation be exposed as an optional argument of SDPA? Papers like the Swin Transformer (https://arxiv.org/pdf/2103.14030) add a relative positional bias to the attention scores (their equation (4)). It would be great to have an efficient implementation of that as well!
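For concreteness, this is the kind of bias I mean, sketched with made-up shapes following eq. (4) of the Swin paper. As far as I can tell from the reference code above, a non-bool attn_mask is simply added to the scores, so maybe the bias can already be passed through that argument; what I am really after is the fused kernels supporting it efficiently.

```python
# Sketch only: Swin-style relative position bias; shapes and values are illustrative.
import math
import torch
import torch.nn.functional as F

B, H, L, E = 6, 12, 60, 32                          # batch, heads, tokens, head dim
q = torch.randn(B, H, L, E, device="cuda")
k = torch.randn(B, H, L, E, device="cuda")
v = torch.randn(B, H, L, E, device="cuda")
rel_pos_bias = torch.randn(H, L, L, device="cuda")  # learnable bias, one (L, L) map per head

# What I do today, by hand:
scores = q @ k.transpose(-2, -1) / math.sqrt(E) + rel_pos_bias  # bias broadcasts over batch
out_manual = torch.softmax(scores, dim=-1) @ v

# What I hope is (or becomes) the efficient path: pass the bias as a float attn_mask.
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias)
```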
Thanks a lot already!