Accelerate attention with SDPA

Yes, that’s correct. I thought doing it explicitly might help, so I separated the concerns and ran the following code as a quick check; I found no issues with the implementation.

import time
import math
import torch
import torch.nn.functional as F

# Inputs in (batch, num_heads, seq_len, head_dim) layout
query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    # Materialize the full (L, S) attention matrix, then softmax and weight the values
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

def benchmark(func, *args, num_runs=500, warmup=10):
    # Warmup
    for _ in range(warmup):
        _ = func(*args)

    torch.cuda.synchronize()

    start = time.time()
    for _ in range(num_runs):
        _ = func(*args)

    torch.cuda.synchronize()

    return (time.time() - start) / num_runs

# Benchmark both approaches
builtin_time = benchmark(F.scaled_dot_product_attention, query, key, value)
manual_time = benchmark(scaled_dot_product_attention, query, key, value)

print(f"Built-in time: {builtin_time*1000:.2f} ms")
print(f"Manual time: {manual_time*1000:.2f} ms")

I got the following results:

Built-in time: 0.04 ms
Manual time: 0.17 ms
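That gap is expected: the built-in F.scaled_dot_product_attention dispatches to a fused kernel (FlashAttention or the memory-efficient backend) when the inputs allow it, instead of materializing the full attention matrix like the manual version does. If you want to confirm which backend your own call can use, something along these lines should work; this is only a sketch and assumes a recent PyTorch (torch.nn.attention.sdpa_kernel is available from roughly 2.3 onwards, older releases expose a similar torch.backends.cuda.sdp_kernel context manager) and a GPU that supports FlashAttention.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")

# Restrict SDPA to a single backend; if that backend cannot handle the
# given shapes/dtypes, the call raises instead of silently falling back.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = F.scaled_dot_product_attention(q, k, v)

with sdpa_kernel(SDPBackend.MATH):
    out_math = F.scaled_dot_product_attention(q, k, v)

# The fused and math paths should agree up to fp16 rounding differences
print("max abs diff:", (out_flash - out_math).abs().max().item())

If the FLASH_ATTENTION block raises a "no available kernel" error, that usually points to an unsupported dtype, head dimension, or GPU, which can itself be the reason SDPA runs slower than expected in a larger model.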

There might be an issue, often a subtle one, in your code. Carefully check the shapes passed to the SDPA function and also the arguments used in the MultiheadAttention module (if you’re using that in your model); see the sketch below.
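To illustrate what I mean by checking the shapes: F.scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim), while nn.MultiheadAttention defaults to (seq_len, batch, embed_dim) inputs unless it is constructed with batch_first=True. The sizes below are just placeholders for illustration, not taken from your model:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq_len, num_heads, head_dim = 32, 128, 8, 64
embed_dim = num_heads * head_dim

# SDPA works on (batch, num_heads, seq_len, head_dim) tensors
q = torch.rand(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, q, q)
assert out.shape == (batch, num_heads, seq_len, head_dim)

# MultiheadAttention takes (batch, seq_len, embed_dim) only with batch_first=True;
# its default layout is (seq_len, batch, embed_dim)
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True,
                            device="cuda", dtype=torch.float16)
x = torch.rand(batch, seq_len, embed_dim, device="cuda", dtype=torch.float16)
attn_out, _ = mha(x, x, x, need_weights=False)
assert attn_out.shape == (batch, seq_len, embed_dim)

Also note that nn.MultiheadAttention only takes its fused fast path when need_weights=False (plus a few other conditions), so requesting the attention weights is another common reason for it being slower than expected.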

For reference, you might find the following threads helpful: