I just wanted to confirm: what is the best way to ensure that only the new Flash Attention in PyTorch 2.0 is being used for scaled dot product attention?
For example:
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask,
        dropout_p=flash_attn_dropout,
        is_causal=causal,
        scale=scale
    )
I compared (True, False, False) (i.e. forcing flash attention) vs. (False, True, True). As you said, I expected the former to be faster, but it's slightly slower (roughly 7 vs. 6 seconds per iteration).
I also ran @ptrblck 's code snippet and got the same results when printing the lines.
I'm using torch 2.0 on an A100-80GB. Is there anything I might be missing?
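For reference, one way to check which kernel actually executes (a rough sketch with toy tensors, not necessarily the snippet referenced above) is to profile a single forward pass and look at the CUDA kernel names:

import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

# toy tensors: [batch, n_heads, seq_len, head_dim], fp16 on CUDA so flash is eligible
q = torch.randn(2, 40, 1024, 64, dtype=torch.float16, device='cuda')
k = torch.randn(2, 40, 1024, 64, dtype=torch.float16, device='cuda')
v = torch.randn(2, 40, 1024, 64, dtype=torch.float16, device='cuda')

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# the flash path shows up as a kernel with "flash"/"fmha" in its name
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))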
That is interesting. It should be both much faster and more memory-efficient on an A100 (80GB) due to the increased bandwidth. Have you opened an issue on the PyTorch GitHub with the benchmarks for your tests?
That's interesting! I didn't open an issue there (yet). Just FYI (or for anyone), the training was with DeepSpeed stage 2 and 3 and a per-GPU batch size of around 2, but n_head is around 40, so the number that I believe matters, n_head * batch_size, should be large enough.
I wrote the following toy snippet to evaluate the flash-attention speed-up. The code outputs:
Flash attention took 0.0018491744995117188 seconds
Standard attention took 0.6876699924468994 seconds
Notice the following:
1- I am using float16 on CUDA, because flash-attention supports float16 and bfloat16.
2- Flash-attention aggregates multiple operations into a single fused kernel. Thus, more operations lead to more savings. In my code snippet, I am doing only matmul, softmax, and dropout. I believe further speed-up can be gained by adding the mask operation as well.
3- Flash-attention can support longer sequences than standard attention can. For instance, my GPU can run flash-attention with seq_len=4096, but throws an OOM error with standard attention.
import time

import torch
import torch.nn.functional as F

bz = 32
seq_len = 2048
dims = 64
n_heads = 8

q = torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16).cuda()
k = torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16).cuda()
v = torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16).cuda()

dropout_rate = 0.2
num_trials = 10

# fused flash-attention path
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    start = time.time()
    for i in range(num_trials):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate)
    end = time.time()
print('Flash attention took {} seconds'.format(end - start))

# standard (unfused) attention: matmul -> softmax -> dropout -> matmul
start = time.time()
for i in range(num_trials):
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    attn = F.dropout(attn, p=dropout_rate, training=True)
    x = (attn @ v).transpose(1, 2)  # .reshape(bz, seq_len, n_heads*dims)
end = time.time()
print('Standard attention took {} seconds'.format(end - start))
CUDA kernels are executed asynchronously, so you would need to synchronize the code before starting and stopping the host timers. Otherwise you would profile the dispatching, kernel launches, or implicit syncs, making your profile invalid.
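For example, the flash-attention timing loop above could be adjusted roughly like this (a sketch of the same snippet with synchronization added):

# sketch: synchronize before reading the host timers so the measurement covers the
# actual kernel execution rather than just the asynchronous launches
torch.cuda.synchronize()
start = time.time()
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    for i in range(num_trials):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate)
torch.cuda.synchronize()
end = time.time()
print('Flash attention took {} seconds'.format(end - start))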
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attention_mask,
        dropout_p=self.dropout if self.training else 0.0,
        is_causal=False
    )
I tried to run this code, but got an error. The error info is here:
<string>:1: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:545.)
<string>:1: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:338.)
<string>:1: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:547.)
<string>:1: UserWarning: Both fused kernels do not support non-null attn_mask. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:191.)
Traceback (most recent call last):
File "<string>", line 1, in <module>
RuntimeError: No available kernel. Aborting execution.
But when I remove the attn_mask parameter, it works.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(
        q, k, v,
        dropout_p=self.dropout if self.training else 0.0,
        is_causal=False
    )
The attention_mask shape is [bz, seq_len, target_len, src_len].
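Based on the warnings above (both fused kernels reject a non-null attn_mask in this build), one possible workaround, sketched here rather than a confirmed fix, is to leave the other backends enabled so PyTorch can fall back to a kernel that accepts the mask:

# sketch: keep the math fallback enabled whenever a non-null attn_mask is required,
# since the fused kernels in this build refuse it (per the warnings above)
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=True, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attention_mask,
        dropout_p=self.dropout if self.training else 0.0,
        is_causal=False
    )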
Noob question: why can't you set everything to True, like so? Wouldn't that make it more memory efficient?
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
flash / math / mem_efficient are different backends. By setting all of them to True, you let PyTorch choose the most favorable one; by setting only one to True, you force that backend and let the call fail if it is not available. Usually you want to force flash attention for the best speed and then check why it may fail.
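For example (a rough sketch reusing q, k, v from the earlier toy snippet), you can force flash attention and fall back to automatic backend selection only if the forced call fails:

# sketch: force flash attention first; if it raises (e.g. "No available kernel.
# Aborting execution." as in the traceback above), let PyTorch choose a backend
try:
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
except RuntimeError as e:
    print('flash attention unavailable, letting PyTorch choose a backend:', e)
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=True, enable_mem_efficient=True
    ):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)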