No, it is not a bug.
SDPA is much faster (more efficient) with causal attention.

For a correct benchmark, the `shape` variable in your script needs to be set to the appropriate format:
```python
shape = [60, 8, 512, 32]  # batch_size, num_heads, seq_length, head_dim
```
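In case the rest of your timing harness differs from mine, here is a minimal sketch of the kind of comparison being run. `manual_attention` and `time_fn` are my own hypothetical helpers, not code from your script, and the numbers in the table below come from your script, not from this sketch.

```python
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, causal=False):
    # Naive scaled dot-product attention written out in plain PyTorch.
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    if causal:
        # Mask out positions strictly above the diagonal (future tokens).
        seq_len = q.size(-2)
        mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device),
            diagonal=1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

def time_fn(fn, iters=50):
    # Average latency in ms, measured with CUDA events.
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

shape = [60, 8, 512, 32]  # batch_size, num_heads, seq_length, head_dim
q, k, v = (torch.randn(shape, device="cuda", dtype=torch.float16) for _ in range(3))

print("manual, causal:", time_fn(lambda: manual_attention(q, k, v, causal=True)))
print("SDPA,   causal:", time_fn(lambda: F.scaled_dot_product_attention(q, k, v, is_causal=True)))
```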
Now execute the script. The table below shows the comparison (latencies in ms) as the sequence length increases from 512 to 1024 to 4096:
| Implementation | seq_len = 512 | seq_len = 1024 | seq_len = 4096 |
|---|---|---|---|
| Manual Implementation (no causal masking) | 13.67 | 53.86 | 140.16 |
| Inefficient Implementation (causal masking) | 18.03 | 70.76 | 214.05 |
| SDPA (causal masking) | 3.95 | 15.40 | 40.38 |
Your implementation did not include a causal mask. Adding a causal mask to the attention scores adds extra FLOPs (and therefore latency) that grow quadratically with sequence length. Even so, compared to the inefficient implementation with the same causal masking, SDPA is roughly five times faster!
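To get a feel for why that overhead grows quadratically: the mask touches one score entry per (query, key) pair, i.e. seq_length × seq_length entries per head per batch element. This is a back-of-the-envelope sketch, not an exact FLOP count of the real kernels:

```python
for seq_len in (512, 1024, 4096):
    # One masked score entry per (query position, key position) pair.
    extra_entries = seq_len * seq_len
    print(f"seq_len={seq_len:5d} -> {extra_entries:,} extra score entries per head")
# seq_len=  512 -> 262,144 extra score entries per head
# seq_len= 1024 -> 1,048,576 extra score entries per head
# seq_len= 4096 -> 16,777,216 extra score entries per head
```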
These results are from running your script on my device (an L40 GPU). The only change I made was the `shape` variable, as mentioned above.
Hope it helps!
Regards
~Arun