Yes, that’s correct. I thought doing it explicitly might help, so I separated the concerns and ran the code below as a quick check, and found that the implementation has no issues.
import time
import math
import torch
import torch.nn.functional as F
# (batch, num_heads, seq_len, head_dim)
query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
# Manual SDPA, following the reference implementation shown in the
# F.scaled_dot_product_attention documentation.
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                                 is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        # Build the lower-triangular causal mask on the same device as the inputs
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias
    if enable_gqa:
        # Expand key/value heads to match the number of query heads
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
def benchmark(func, *args, num_runs=500, warmup=10):
    # Warmup
    for _ in range(warmup):
        _ = func(*args)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_runs):
        _ = func(*args)
    torch.cuda.synchronize()
    return (time.time() - start) / num_runs
# Benchmark both approaches
builtin_time = benchmark(F.scaled_dot_product_attention, query, key, value)
manual_time = benchmark(scaled_dot_product_attention, query, key, value)
print(f"Built-in time: {builtin_time*1000:.2f} ms")
print(f"manual time: {manual_time*1000:.2f} ms")
Running the benchmark, I got the following results:
Built-in time: 0.04 ms
manual time: 0.17 ms
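The gap itself is expected: the built-in F.scaled_dot_product_attention can dispatch to fused kernels (FlashAttention or memory-efficient attention), while the manual version materializes the full attention matrix and launches several separate kernels. If you want to check whether a fused backend is usable for your inputs, you can restrict the dispatcher explicitly; this is a minimal sketch assuming PyTorch 2.3+ (where torch.nn.attention.sdpa_kernel is available) and reusing the tensors from the snippet above:

from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to the FlashAttention backend; this raises a RuntimeError
# at call time if the inputs (dtype, head dim, mask, ...) are not supported.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(query, key, value)
print("FlashAttention backend handled these inputs")

Without the context manager, PyTorch picks a backend automatically, so this is only useful as a diagnostic.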
That said, there might be a subtle issue in your code. Carefully check the shapes you pass to the SDPA function, as well as the arguments used in the MultiheadAttention module (if your model uses it). For reference, you might find the following threads helpful: