Flash Attention

Hi @ptrblck,

I just wanted to confirm the best way to ensure that only the new Flash Attention kernel in PyTorch 2.0 is used for scaled dot product attention.

For example:

import torch
import torch.nn.functional as F

# PyTorch 2.0 flash attention: q, k, v, mask, dropout, causal, softmax_scale
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask,
        dropout_p=flash_attn_dropout,
        is_causal=causal,
        scale=scale  # the `scale` argument requires PyTorch 2.1 or later
    )

I greatly appreciate your help.

Thank you,

Enrico

torch.backends.cuda.enable_flash_sdp is not a context manager; you can use with torch.backends.cuda.sdp_kernel instead, as seen here:

import torch

print(torch.backends.cuda.flash_sdp_enabled())
# True
print(torch.backends.cuda.mem_efficient_sdp_enabled())
# True
print(torch.backends.cuda.math_sdp_enabled())
# True

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    print(torch.backends.cuda.flash_sdp_enabled())
    # True
    print(torch.backends.cuda.mem_efficient_sdp_enabled())
    # False
    print(torch.backends.cuda.math_sdp_enabled())
    # False
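For comparison, the global setters can be used directly, but since they change process-wide state you have to restore the previous values yourself, which is roughly what sdp_kernel does for you on exit. A minimal sketch; the helper names snapshot_sdp_flags and set_sdp_flags are hypothetical, while the torch.backends.cuda getters and setters they wrap are the real ones:

```python
import torch

def snapshot_sdp_flags():
    # Hypothetical helper: read the current process-wide SDPA backend flags.
    return {
        "flash": torch.backends.cuda.flash_sdp_enabled(),
        "mem_efficient": torch.backends.cuda.mem_efficient_sdp_enabled(),
        "math": torch.backends.cuda.math_sdp_enabled(),
    }

def set_sdp_flags(flash, mem_efficient, math):
    # Hypothetical helper: plain global setters, NOT context managers.
    torch.backends.cuda.enable_flash_sdp(flash)
    torch.backends.cuda.enable_mem_efficient_sdp(mem_efficient)
    torch.backends.cuda.enable_math_sdp(math)

before = snapshot_sdp_flags()
try:
    set_sdp_flags(flash=True, mem_efficient=False, math=False)
    inside = snapshot_sdp_flags()
finally:
    # Restore the previous flags, as the context manager would on exit.
    set_sdp_flags(**before)

print(inside)                          # {'flash': True, 'mem_efficient': False, 'math': False}
print(snapshot_sdp_flags() == before)  # True
```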

Thank you for verifying that it should be used like this:

with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    ...  # F.scaled_dot_product_attention(...) calls go here

I appreciate it, as always.

Best,

Enrico


Hi @EnricoShippole, did you get any performance improvement by this change to force flash attention only?

Hi @kchoi ,

You should see improvements in both speed and memory consumption. You can check out some small baseline models I recently trained here: GitHub - conceptofmind/PaLM: An open-source implementation of Google's PaLM models

Thank you,

Enrico

thanks a lot!

I compared (True, False, False), i.e. forcing flash attention, against (False, True, True). As you said, I expected the former to be faster, but it's slightly slower (roughly 7 vs 6 seconds per iteration).

I also ran @ptrblck's code snippet and got the same results when printing the flags.

I'm using torch 2.0 on an A100 80GB. Is there anything I might be missing?
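A minimal sketch for reproducing this kind of comparison, assuming PyTorch 2.x: it times each backend configuration with a warm-up call and explicit synchronization, and reports None when no enabled kernel accepts the inputs. The helper time_sdpa and the tensor shapes are hypothetical and illustrative, not the setup used above; flags are passed in the same (flash, math, mem_efficient) order as the sdp_kernel calls in this thread.

```python
import time
import torch
import torch.nn.functional as F

def time_sdpa(q, k, v, enable_flash, enable_math, enable_mem_efficient, iters=10):
    # Hypothetical helper: average seconds per call, or None if no enabled kernel runs.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=enable_flash,
        enable_math=enable_math,
        enable_mem_efficient=enable_mem_efficient,
    ):
        try:
            F.scaled_dot_product_attention(q, k, v)  # warm-up and eligibility check
        except RuntimeError:
            return None  # e.g. "No available kernel. Aborting execution."
        if q.is_cuda:
            torch.cuda.synchronize()  # wait for the warm-up before starting the timer
        start = time.perf_counter()
        for _ in range(iters):
            F.scaled_dot_product_attention(q, k, v)
        if q.is_cuda:
            torch.cuda.synchronize()  # wait for all queued kernels before stopping it
        return (time.perf_counter() - start) / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # flash on CUDA needs fp16/bf16
q, k, v = (torch.randn(2, 4, 128, 64, device=device, dtype=dtype) for _ in range(3))

configs = {
    "flash only": (True, False, False),
    "mem_efficient + math": (False, True, True),
}
results = {name: time_sdpa(q, k, v, *flags) for name, flags in configs.items()}
print(results)
```

Small shapes like these mostly measure launch overhead; the sizes from the actual model are what matter for the comparison.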

That is interesting. It should be both much faster and more memory-efficient on an A100 (80GB) due to the increased bandwidth. Have you opened an issue on the PyTorch GitHub with your benchmark results?

I have been speaking to a few different peers and they are noticing results similar to yours. I will have to test the Triton version by Tri Dao too.

That's interesting! I didn't open an issue there (yet). Just FYI (or for anyone else reading), the training used DeepSpeed stages 2 and 3 with a per-GPU batch size of around 2, but n_head is about 40, so the number that I believe matters, n_head * batch_size, should be large enough.

Can you try with a head dim of 128?
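To try a head dim of 128 without changing the model width, the same hidden size can be split into fewer, wider heads. A minimal sketch, assuming the projected q/k/v come out as (batch, seq_len, d_model); split_heads and the d_model value are hypothetical. F.scaled_dot_product_attention expects (batch, n_heads, seq_len, head_dim).

```python
import torch

def split_heads(x, head_dim):
    # Hypothetical helper: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, head_dim)
    batch, seq_len, d_model = x.shape
    if d_model % head_dim != 0:
        raise ValueError(f"d_model={d_model} is not divisible by head_dim={head_dim}")
    n_heads = d_model // head_dim
    return x.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)

x = torch.randn(2, 16, 5120)       # hypothetical d_model = 5120
print(split_heads(x, 128).shape)   # torch.Size([2, 40, 16, 128])
print(split_heads(x, 64).shape)    # torch.Size([2, 80, 16, 64])
```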

I wrote the following toy snippet to evaluate the Flash Attention speedup. The code outputs:

Flash attention took 0.0018491744995117188 seconds
Standard attention took 0.6876699924468994 seconds

Note the following:
1- I am using float16 on CUDA because Flash Attention only supports float16 and bfloat16.
2- Flash Attention fuses multiple operations into a single kernel, so the more operations it fuses, the larger the savings. My snippet only does matmul, softmax, and dropout; I believe a further speedup could be gained by adding the mask operation as well.
3- Flash Attention can handle longer sequences than standard attention. For instance, my GPU can run Flash Attention with seq_len=4096 but throws an OOM error with standard attention.
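Point 3 follows from the size of the score matrix that standard attention materializes: one seq_len x seq_len matrix per batch element and head. A back-of-the-envelope sketch with the snippet's sizes (bz=32, n_heads=8, float16 = 2 bytes); it counts only one such tensor, while the softmax and dropout outputs add further copies. Flash Attention computes the result in tiles and never materializes this matrix. The helper name score_matrix_gib is hypothetical.

```python
def score_matrix_gib(bz, n_heads, seq_len, bytes_per_elem=2):
    # Hypothetical helper: size of one (bz, n_heads, seq_len, seq_len) tensor in GiB.
    return bz * n_heads * seq_len * seq_len * bytes_per_elem / 2**30

for seq_len in (2048, 4096):
    print(seq_len, score_matrix_gib(32, 8, seq_len))
# 2048 2.0
# 4096 8.0
```

Doubling seq_len quadruples this term, which is why the standard path runs out of memory first.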

import time
import torch
import torch.nn.functional as F


bz = 32
seq_len = 2048
dims = 64
n_heads = 8
q = torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16).cuda()
k = torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16).cuda()
v = torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16).cuda()

dropout_rate = 0.2
num_trials = 10

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    start = time.time()
    for i in range(num_trials):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate)
    end = time.time()
    print('Flash attention took {} seconds'.format(end - start))


start = time.time()
for i in range(num_trials):
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    attn = F.dropout(attn, p=dropout_rate, training=True)
    x = (attn @ v).transpose(1, 2)  # .reshape(bz, seq_len, n_heads*dims)
end = time.time()
print('Standard attention took {} seconds'.format(end - start))


CUDA kernels are executed asynchronously, so you need to synchronize the code before starting and stopping the host timers. Otherwise you would be profiling the dispatching, kernel launches, or implicit syncs, which makes your profile invalid.
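One way to avoid hand-rolled host timers is torch.utils.benchmark.Timer, which runs warm-up iterations and synchronizes CUDA when needed. A minimal sketch, assuming small illustrative shapes so it also runs on CPU:

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# Illustrative shapes only: (batch, n_heads, seq_len, head_dim).
q, k, v = (torch.randn(2, 4, 128, 64, device=device, dtype=dtype) for _ in range(3))

timer = benchmark.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v)",
    globals={"F": F, "q": q, "k": k, "v": v},
)
measurement = timer.timeit(20)  # warm-up runs happen before the 20 timed calls
print(f"{measurement.mean * 1e3:.3f} ms per call")
```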


Thanks for catching this issue. I updated the code accordingly. Please let me know if you see other mistakes.

I also switched the order of standard and flash attention evaluations as a sanity check.
The current output is

Standard attention took 0.8632566928863525 seconds for 10 trials
Flash attention took 0.07728338241577148 seconds for 10 trials

The updated code snippet is

import time
import torch
import torch.nn.functional as F


bz = 32
seq_len = 2048
dims = 64
n_heads = 8
q = torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16).cuda()
k = torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16).cuda()
v = torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16).cuda()

dropout_rate = 0.2
num_trials = 10


torch.cuda.synchronize()
start = time.time()
for i in range(num_trials):
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    attn = F.dropout(attn, p=dropout_rate, training=True)
    x = (attn @ v).transpose(1, 2)  # .reshape(bz, seq_len, n_heads*dims)
torch.cuda.synchronize()
end = time.time()
print('Standard attention took {} seconds for {} trials'.format(end - start, num_trials))

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    torch.cuda.synchronize()
    start = time.time()
    for i in range(num_trials):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate)
    torch.cuda.synchronize()
    end = time.time()
    print('Flash attention took {} seconds for {} trials'.format(end - start, num_trials))
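Before comparing timings, it can help to check that both paths compute the same thing. By default F.scaled_dot_product_attention scales the scores by 1/sqrt(head_dim) and returns (batch, n_heads, seq_len, head_dim), so a manual reference has to apply the same factor to be comparable. A minimal sketch with dropout disabled (dropout makes outputs random), using float32 on CPU and small illustrative shapes; the tolerance is an assumption.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bz, n_heads, seq_len, dims = 2, 4, 64, 32  # illustrative sizes
q, k, v = (torch.randn(bz, n_heads, seq_len, dims) for _ in range(3))

# Manual reference: scaled scores -> softmax -> weighted sum of values (no dropout).
scores = (q @ k.transpose(-2, -1)) / math.sqrt(dims)
reference = scores.softmax(dim=-1) @ v

# SDPA with dropout disabled; its default scale is 1/sqrt(dims).
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

print(out.shape)                                  # torch.Size([2, 4, 64, 32])
print(torch.allclose(out, reference, atol=1e-4))  # True
```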

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attention_mask,
        dropout_p=self.dropout if self.training else 0.0,
        is_causal=False,
    )

I tried to run this code, but it raises an error:

<string>:1: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:545.)
<string>:1: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:338.)
<string>:1: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:547.)
<string>:1: UserWarning: Both fused kernels do not support non-null attn_mask. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:191.)
Traceback (most recent call last):
  File "<string>", line 1, in <module>
RuntimeError: No available kernel.  Aborting execution.

But when I remove the attn_mask parameter, it works:

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(
        q, k, v,
        dropout_p=self.dropout if self.training else 0.0,
        is_causal=False,
    )

The attention_mask shape is [bz, seq_len, target_len, src_len].


If I want to pass the attn_mask parameter, what should I do?


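I know it's been forever, but only the math and memory-efficient kernels support the attn_mask parameter (the memory-efficient kernel gained mask support in PyTorch 2.1; in 2.0, as the warning above says, neither fused kernel accepts a non-null attn_mask).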
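A minimal sketch of passing a mask with the flash kernel left disabled, assuming PyTorch 2.x. Per the SDPA docs, attn_mask must be broadcastable to (batch, n_heads, target_len, src_len), and in a boolean mask True marks positions that may be attended to. The sizes and padding layout below are hypothetical.

```python
import torch
import torch.nn.functional as F

bz, n_heads, tgt_len, src_len, dims = 2, 4, 8, 8, 16  # illustrative sizes
q = torch.randn(bz, n_heads, tgt_len, dims)
k = torch.randn(bz, n_heads, src_len, dims)
v = torch.randn(bz, n_heads, src_len, dims)

# Hypothetical padding: the second sequence has only 5 valid source positions.
valid_src = torch.tensor([8, 5])
key_is_valid = torch.arange(src_len)[None, :] < valid_src[:, None]  # (bz, src_len)
attn_mask = key_is_valid[:, None, None, :]  # (bz, 1, 1, src_len): broadcasts over heads and queries

with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=True, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

print(out.shape)  # torch.Size([2, 4, 8, 16])
```

Every row keeps at least one valid key here; a row with no True entries would produce NaNs.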

Noob question: why can't you set everything to True, like so? Wouldn't that make it more memory-efficient?
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
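As far as I know, enabling all three flags does not combine the kernels: for each call the dispatcher picks a single eligible backend, trying the fused kernels before the math fallback. So all-True uses no less memory than flash alone when flash is eligible, but when it isn't (for example with an attn_mask) it silently falls back instead of failing. A minimal sketch, assuming PyTorch 2.x; runs_with is a hypothetical helper, and which single-backend configs succeed depends on version, device, dtype, and mask.

```python
import torch
import torch.nn.functional as F

def runs_with(q, k, v, attn_mask, flash, math, mem_efficient):
    # Hypothetical helper: True if some enabled backend accepts these inputs.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=flash, enable_math=math, enable_mem_efficient=mem_efficient
    ):
        try:
            F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            return True
        except RuntimeError:  # e.g. "No available kernel. Aborting execution."
            return False

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q, k, v = (torch.randn(2, 4, 128, 64, device=device, dtype=dtype) for _ in range(3))
mask = torch.ones(128, 128, dtype=torch.bool, device=device).tril()  # causal boolean mask

report = {
    "flash only": runs_with(q, k, v, mask, True, False, False),
    "all enabled": runs_with(q, k, v, mask, True, True, True),
}
print(report)  # "all enabled" should be True: the math kernel accepts any mask
```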