Accelerating attention with SDPA

Hi,

I tried to use SDPA (F.scaled_dot_product_attention) to speed up the attention computation. I used time.time() to measure how long it takes to train on one batch (I use gradient accumulation, so this is one optimizer step). Here is how I record the time:

    for epoch in range(config.epochs):
        if os.path.exists(config.model_save_path) and config.resume:
            losses = checkpoint['loss']
        else:
            losses = 0
       
        t = time.time()
        for idx, batch in enumerate(train_iter):
            if config.ddp:
                model.require_backward_grad_sync = (idx % config.gradient_accumulation_steps == config.gradient_accumulation_steps - 1)
            # batch_size seq
            b_token_ids = batch['input_ids'].t().to(config.device)  
            b_segs = batch['token_type_ids'].t().to(config.device)  
            b_mask = batch['attention_mask'].to(config.device)  
            b_mlm_label = batch['labels'].t().to(config.device)  


            with amp.autocast(device_type='cuda',dtype=torch.bfloat16):    # for forward function! bfloat16
                # with autocast(enabled=False):
                loss, mlm_logits = model(input_ids=b_token_ids,
                                         attention_mask=b_mask,
                                         token_type_ids=b_segs,
                                         masked_lm_labels=b_mlm_label,
                                         next_sentence_labels=None)

                loss = loss / config.gradient_accumulation_steps  # scale the loss to account for gradient accumulation
            scaler.scale(loss).backward()

            if ((idx  % config.gradient_accumulation_steps) == (config.gradient_accumulation_steps - 1)):
                if config.grad_clip != 0.0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
                print(f'Time for one step(one batchsize):{time.time() - t}')
                t = time.time()

            losses += (loss.item() * config.gradient_accumulation_steps)

            mlm_acc, _, _ = accuracy(mlm_logits, b_mlm_label, config.pad_index)
            ....
       

Then I switched my attention implementation to SDPA.

This is my original version:

        .....
        if attn_mask is not None:
            attn_output_weights += attn_mask  # [batch_size * num_heads, tgt_len, src_len]

        .....

        attn_output_weights = F.softmax(attn_output_weights, dim=-1) 
        attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=True) 
        attn_output = self.atten_qkv(v, attn_output_weights)

This is SDPA version:

    ....
    attn_output = F.scaled_dot_product_attention(q, k, v-self.atten_qkv.t_min_1, dropout_p=self.dropout)
    ....

Things are weird, because the two versions take almost the same time. I'm not sure whether I'm using F.scaled_dot_product_attention correctly. I tested FP32, FP16, and BF16, and they all take almost the same time.
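Before comparing speed, it may be worth checking that the two versions compute the same thing. Two differences stand out in the snippets above: the SDPA call passes no `attn_mask` while the manual version adds one, and SDPA scales the scores by 1/sqrt(head_dim) by default, which the manual version does not visibly do (it may happen in the elided part). SDPA also applies dropout whenever `dropout_p > 0`, regardless of train/eval mode. The sketch below assumes 4-D inputs and an additive float mask; `manual_attention`, `sdpa_attention` and the padding mask are hypothetical names, not the poster's code.

```python
import math

import torch
import torch.nn.functional as F


def manual_attention(q, k, v, attn_mask=None, dropout_p=0.0, training=True):
    # q, k, v: [batch, heads, seq_len, head_dim]; attn_mask is additive (0 or -inf)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # explicit 1/sqrt(d) scaling
    if attn_mask is not None:
        scores = scores + attn_mask
    weights = torch.softmax(scores, dim=-1)
    weights = F.dropout(weights, p=dropout_p, training=training)
    return weights @ v


def sdpa_attention(q, k, v, attn_mask=None, dropout_p=0.0, training=True):
    # SDPA has no notion of train/eval, so gate dropout on the training flag ourselves
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=dropout_p if training else 0.0
    )


torch.manual_seed(0)
B, H, L, D = 2, 4, 16, 8
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# Hypothetical padding mask: the last 3 key positions of batch element 1 are padding
mask = torch.zeros(B, 1, 1, L)
mask[1, ..., -3:] = float("-inf")

out_manual = manual_attention(q, k, v, attn_mask=mask, training=False)
out_sdpa = sdpa_attention(q, k, v, attn_mask=mask, training=False)
print(torch.allclose(out_manual, out_sdpa, atol=1e-5))
```

If the outputs disagree in your own model, the timing comparison is between two different computations, so it is worth fixing that first.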

CUDA operations are executed asynchronously, so you need to synchronize the code before starting and stopping host timers.
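A minimal sketch of that advice: synchronize before starting and before stopping the host timer, or use CUDA events, which record timestamps on the GPU stream itself. The helper names `time_fn` and `time_fn_events` are hypothetical; on a machine without CUDA the sketch simply times on the CPU.

```python
import time

import torch


def time_fn(fn, *args, warmup=3, iters=10, **kwargs):
    """Average wall-clock seconds per call, synchronizing CUDA around the timed region."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):  # warm up kernels, allocator caches, autotuning
        fn(*args, **kwargs)
    if use_cuda:
        torch.cuda.synchronize()  # wait for queued work before starting the host timer
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    if use_cuda:
        torch.cuda.synchronize()  # wait for all launched kernels before stopping the timer
    return (time.perf_counter() - start) / iters


def time_fn_events(fn, *args, iters=10, **kwargs):
    """Alternative that measures on the GPU stream with CUDA events (milliseconds per call)."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()  # elapsed_time() is only valid once both events have completed
    return start.elapsed_time(end) / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(512, 512, device=device)
print(f"matmul: {time_fn(torch.matmul, x, x) * 1e3:.3f} ms")
if torch.cuda.is_available():
    print(f"matmul (events): {time_fn_events(torch.matmul, x, x):.3f} ms")
```

Without the synchronize calls, the host timer mostly measures how long it takes to enqueue kernels, not how long they take to run.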


Thank you for your reply!

After learning about the difference between GPU time and CPU time, I added torch.cuda.synchronize() before recording the time.

My code uses gradient accumulation:

    for idx, batch in enumerate(train_iter):
        if config.ddp:
            model.require_backward_grad_sync = (idx % config.gradient_accumulation_steps == config.gradient_accumulation_steps - 1)
        b_token_ids = batch['input_ids'].t().to(config.device) 
        b_segs = batch['token_type_ids'].t().to(config.device)  
        b_mask = batch['attention_mask'].to(config.device)  
        b_mlm_label = batch['labels'].t().to(config.device)  

        if idx  % config.gradient_accumulation_steps == 0:#!!
            torch.cuda.synchronize()
            t0 = time.perf_counter()


        # with record_function("## forward ##"):
        # Use autocast context manager for automatic mixed precision
        with amp.autocast(device_type='cuda',dtype=torch.bfloat16):    # for forward function! bfloat16
            # with autocast(enabled=False):
            loss, mlm_logits = model(input_ids=b_token_ids,
                                     attention_mask=b_mask,
                                     token_type_ids=b_segs,
                                     masked_lm_labels=b_mlm_label,
                                     next_sentence_labels=None)

            loss = loss / config.gradient_accumulation_steps  # scale the loss to account for gradient accumulation

        scaler.scale(loss).backward()

        if ((idx  % config.gradient_accumulation_steps) == (config.gradient_accumulation_steps - 1)):
            if config.grad_clip != 0.0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            torch.cuda.synchronize()
            p = time.perf_counter()
            print(f'rank:{ddp_local_rank}!!!!!!!Time for one step(one batchsize):{p - t0}b16')#!!

Without synchronization (float32), CPU time in seconds:

5.6577911376953125

With synchronization (float32):

torch.compile() + SDPA:

rank:2!!!.Time for one step(one batchsize):4.333474938757718!!!
rank:0!!!.Time for one step(one batchsize):4.331189756747335!!!
rank:3!!!.Time for one step(one batchsize):4.331575924996287!!!
rank:1!!!.Time for one step(one batchsize):4.332561417948455!!!

SDPA:

rank:1!!!.Time for one step(one batchsize):5.672491875011474!!!
rank:0!!!.Time for one step(one batchsize):5.672714178916067!!!
rank:2!!!.Time for one step(one batchsize):5.672652362845838!!!
rank:3!!!.Time for one step(one batchsize):5.67308798385784!!!

Without SDPA and without torch.compile():

rank:1!!!Time for one step(one batchsize):5.78729198500514
rank:2!!!Time for one step(one batchsize):5.791289611253887
rank:0!!!Time for one step(one batchsize):5.7841532728634775
rank:3!!!Time for one step(one batchsize):5.786291209049523

Without synchronization (bfloat16), CPU time in seconds:

2.8307549953460693

With synchronization (bfloat16):

SDPA:

rank:0!!!Time for one step(one batchsize):2.8250338532961905b16
rank:3!!!Time for one step(one batchsize):2.8336645336821675b16
rank:2!!!Time for one step(one batchsize):2.8370646508410573b16
rank:1!!!Time for one step(one batchsize):2.834382927045226b16

Without SDPA and without torch.compile():

rank:1!!!Time for one step(one batchsize):2.828244622796774b16
rank:0!!!Time for one step(one batchsize):2.8276006169617176b16
rank:2!!!Time for one step(one batchsize):2.8268630262464285b16
rank:3!!!Time for one step(one batchsize):2.8273697779513896b16

I'm confused, because there seems to be no speedup, and in most cases the unsynchronized CPU time looks the same as the synchronized GPU time. With float32 and torch.compile(), the synchronized step time is smaller. Does that mean we need more CPU workers to process the data faster?
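One way to check the data-loading question is to time the wait for the next batch separately from the step itself. The sketch below is hypothetical (`profile_loop` and `step_fn` are not from the thread), and it synchronizes after every step, which removes CPU/GPU overlap, so treat it as a diagnostic rather than something to leave in normal training. If most of the time goes to `next(it)`, raising `DataLoader(num_workers=...)` or setting `pin_memory=True` may help; if it is mostly compute, more workers probably won't change much.

```python
import time

import torch


def profile_loop(loader, step_fn, max_steps=None):
    """Split wall-clock time into waiting for the next batch and running the step."""
    use_cuda = torch.cuda.is_available()
    data_time = compute_time = 0.0
    steps = 0
    it = iter(loader)
    while max_steps is None or steps < max_steps:
        t0 = time.perf_counter()
        try:
            batch = next(it)  # time spent here is input-pipeline wait
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)
        if use_cuda:
            torch.cuda.synchronize()  # make the step's GPU work count toward compute_time
        t2 = time.perf_counter()
        data_time += t1 - t0
        compute_time += t2 - t1
        steps += 1
    return data_time, compute_time, steps


# Usage with a stand-in "loader" (a list of tensors) and a toy step
batches = [torch.randn(64, 128) for _ in range(5)]
w = torch.randn(128, 128)
data_s, compute_s, n = profile_loop(batches, lambda b: (b @ w).sum())
print(f"{n} steps: data {data_s:.4f}s, compute {compute_s:.4f}s")
```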

I also compared the time of my own attention implementation against SDPA:

    import torch.utils.benchmark as benchmark

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

SDPA:

    print(
        f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, q, k, v-self.atten_qkv.t_min_1):.3f} microseconds")

Output:

The default implementation runs in 2108.397 microseconds
The default implementation runs in 2111.590 microseconds
The default implementation runs in 2108.828 microseconds
The default implementation runs in 2112.032 microseconds
The default implementation runs in 2109.155 microseconds
The default implementation runs in 2110.308 microseconds

My version:

    ...
    def atten(self, q, k, v):
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        attn_output_weights = F.softmax(attn_output_weights, dim=-1)  # [batch_size * num_heads, tgt_len, src_len]
        attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=True)  # voltage
        attn_output = self.atten_qkv(v, attn_output_weights)
        return attn_output, attn_output_weights

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        ...
        print(
            f"The default implementation runs in {benchmark_torch_function_in_microseconds(self.atten, q, k, v):.3f} microseconds")
        ....

Output:

The default implementation runs in 2106.173 microseconds
The default implementation runs in 2104.582 microseconds
The default implementation runs in 2107.339 microseconds
The default implementation runs in 2104.767 microseconds
The default implementation runs in 2106.384 microseconds
The default implementation runs in 2104.327 microseconds
The default implementation runs in 2105.738 microseconds
The default implementation runs in 2104.363 microseconds

I'm very confused about why my code isn't getting any faster.

Did you call SDPA inside the sdpa_kernel context manager? For example:

    from torch.nn.attention import SDPBackend, sdpa_kernel
    from torch.nn.functional import scaled_dot_product_attention

    # Only enable flash attention backend
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        scaled_dot_product_attention(...)

    # Enable the Math or Efficient attention backends
    with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
        scaled_dot_product_attention(...)

Please check out the documentation here. Using SDPA incorrectly can sometimes even slow things down.

Thank you for your reply. But I don't think the context manager is necessary, because this shows that F.scaled_dot_product_attention can be used on its own and is fast.

Yes, that's correct; I thought selecting the backend explicitly might help. So I isolated the attention computation, ran the following quick check, and found no issues with the implementation:

    import time
    import math
    import torch
    import torch.nn.functional as F

    query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
    key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
    value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")

    # Unfused reference implementation for comparison
    def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
            is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
        if is_causal:
            assert attn_mask is None
            # build the causal mask on the same device as attn_bias
            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias = attn_mask + attn_bias

        if enable_gqa:
            # repeat K/V heads so they match the number of query heads
            key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
            value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value

    def benchmark(func, *args, num_runs=500, warmup=10):
        # Warmup
        for _ in range(warmup):
            _ = func(*args)

        torch.cuda.synchronize()  # start timing only after queued work has finished

        start = time.perf_counter()
        for _ in range(num_runs):
            _ = func(*args)

        torch.cuda.synchronize()  # wait for all kernels before stopping the timer

        return (time.perf_counter() - start) / num_runs

    # Benchmark both approaches
    builtin_time = benchmark(F.scaled_dot_product_attention, query, key, value)
    manual_time = benchmark(scaled_dot_product_attention, query, key, value)

    print(f"Built-in time: {builtin_time*1000:.2f} ms")
    print(f"Manual time: {manual_time*1000:.2f} ms")

I got the following results:

Built-in time: 0.04 ms
manual time: 0.17 ms

There might be an issue in your code, and it is often a subtle one. Carefully check the shapes passed to the SDPA function, and also the arguments used in the MultiheadAttention module (if you're using it in your model). For reference, you might find the following threads helpful:


Thank you so much! I think I solved the issue!

Previously, my q, k, v had shape [24*16, 512, 64], i.e. [batch_size * num_heads, src_len, kdim].

Then I changed them to [24, 16, 512, 64], i.e. [batch_size, num_heads, seq_length, head_dim], and it is really fast!

Output:

The default implementation runs in 184.328 microseconds
The default implementation runs in 184.655 microseconds

compared to about 2,110 microseconds with the previous shape.

It's a bit hard to believe that the shapes of Q, K, and V matter this much! I'm curious why. Is it because the fused kernels expect a fixed input shape?
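A possible explanation, hedged because it depends on the PyTorch version: SDPA treats every leading dimension as batch, so a 3-D [batch * num_heads, seq_len, head_dim] input is accepted and gives the same result, but as far as I know the fused flash and memory-efficient CUDA kernels only accept 4-D [batch, num_heads, seq_len, head_dim] inputs, so 3-D input falls back to the slower math kernel. The sketch below (hypothetical helpers `split_heads` and `merge_heads`) converts between the two layouts and checks that the outputs match. It assumes the flattened dimension was built batch-major (index = b * num_heads + h), as in nn.MultiheadAttention-style code.

```python
import torch
import torch.nn.functional as F


def split_heads(x, num_heads):
    """[batch * num_heads, seq_len, head_dim] -> [batch, num_heads, seq_len, head_dim].

    Assumes batch-major flattening (index = b * num_heads + h); check this
    against how your own code built the flattened tensor.
    """
    bh, seq_len, head_dim = x.shape
    return x.reshape(bh // num_heads, num_heads, seq_len, head_dim)


def merge_heads(x):
    """[batch, num_heads, seq_len, head_dim] -> [batch * num_heads, seq_len, head_dim]."""
    b, h, seq_len, head_dim = x.shape
    return x.reshape(b * h, seq_len, head_dim)


torch.manual_seed(0)
batch, num_heads, seq_len, head_dim = 2, 4, 32, 16
q3, k3, v3 = (torch.randn(batch * num_heads, seq_len, head_dim) for _ in range(3))

out3 = F.scaled_dot_product_attention(q3, k3, v3)  # 3-D input
out4 = F.scaled_dot_product_attention(
    split_heads(q3, num_heads), split_heads(k3, num_heads), split_heads(v3, num_heads)
)  # 4-D input: same math, but eligible for fused kernels on GPU
print(torch.allclose(out3, merge_heads(out4), atol=1e-5))
```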

But things are still weird: although F.scaled_dot_product_attention(query, key, value) is fast on its own, when I measure the time for one whole batch, the version without it is slightly faster.

Float32:

My attention:

    def atten(self, q, k, v):
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        attn_output_weights = F.softmax(attn_output_weights, dim=-1)  
        attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=True)  
        attn_output = self.atten_qkv(v, attn_output_weights)
        return attn_output, attn_output_weights
    def forward(..):
        ...
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        attn_output, attn_output_weights = self.atten(q, k, v-self.atten_qkv.t_min_1)    ## batch_size, num_heads, seq_length, head_dim
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        elapsed_us = (t1 - t0) * 1e6  
        print(f'atten time:{elapsed_us}')

Output:

atten time:1663.760282099247 µs
atten time:1693.3688893914223 µs
atten time:1687.0391555130482 µs

F.scaled_dot_product_attention(query, key, value) version:

    def forward():
        ...
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        attn_output, attn_output_weights = self.atten(q, k, v-self.atten_qkv.t_min_1, dropout_p=self.dropout)    
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        elapsed_us = (t1 - t0) * 1e6  
        print(f'atten time:{elapsed_us}')

Output:

atten time:1054.6050034463406 µs
atten time:1051.4757595956326 µs
atten time:1051.2289591133595 µs

But when I measure the time for one batch:


    for epoch in range(config.epochs):
        if os.path.exists(config.model_save_path) and config.resume:
            losses = checkpoint['loss']
        else:
            losses = 0


        for idx, batch in enumerate(train_iter):
            if config.ddp:
                model.require_backward_grad_sync = (idx % config.gradient_accumulation_steps == config.gradient_accumulation_steps - 1)
            b_token_ids = batch['input_ids'].t().to(config.device)  
            b_segs = batch['token_type_ids'].t().to(config.device)  
            b_mask = batch['attention_mask'].to(config.device)  
            b_mlm_label = batch['labels'].t().to(config.device)  

            if idx  % config.gradient_accumulation_steps == 0:
                torch.cuda.synchronize()
                t0 = time.perf_counter()

            with amp.autocast(device_type='cuda',dtype=torch.float32):    
                # with autocast(enabled=False):
                loss, mlm_logits = model(input_ids=b_token_ids,
                                         attention_mask=b_mask,
                                         token_type_ids=b_segs,
                                         masked_lm_labels=b_mlm_label,
                                         next_sentence_labels=None)

                loss = loss / config.gradient_accumulation_steps  # scale the loss to account for gradient accumulation
            scaler.scale(loss).backward()

            if ((idx  % config.gradient_accumulation_steps) == (config.gradient_accumulation_steps - 1)):

                if config.grad_clip != 0.0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
                torch.cuda.synchronize()
                p = time.perf_counter()
                print(f'rank:{ddp_local_rank}!!!!!!!Time for one step(one batchsize):{p - t0}b16')


            losses += (loss.item() * config.gradient_accumulation_steps)
            mlm_acc, _, _ = accuracy(mlm_logits, b_mlm_label, config.pad_index)
        ...
    

My attention output:
rank:1!!!Time for one step(one batchsize):5.563375188037753s
rank:0!!!Time for one step(one batchsize):5.564123823773116s
rank:0!!!Time for one step(one batchsize):5.565241441130638s
rank:1!!!Time for one step(one batchsize):5.5651940810494125s

F.scaled_dot_product_attention(query, key, value) version output:

rank:0!!!Time for one step(one batchsize):5.580610099248588s
rank:1!!!Time for one step(one batchsize):5.579201320186257s
rank:0!!!Time for one step(one batchsize):5.576219431124628s
rank:1!!!Time for one step(one batchsize):5.576044279150665s
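The per-call printouts above (computed as `elapsed_us`, i.e. microseconds) show attention dropping from roughly 1,660 to 1,690 down to about 1,050, a saving of around 0.6 ms per call, while one step takes about 5.57 s. A saving that small only shows up end to end if attention runs very often per step; otherwise the rest of the step (backward, the `v - self.atten_qkv.t_min_1` subtraction, the other layers) hides it. Also, as far as I know, flash attention requires fp16/bf16 inputs, so in float32 SDPA can only use the memory-efficient or math kernel. torch.profiler can show how much of a step each labeled region accounts for; the training loop already has a commented-out `record_function`, which is the labeling tool used here. The model below is a hypothetical stand-in, not the poster's model.

```python
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile, record_function

use_cuda = torch.cuda.is_available()
activities = [ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if use_cuda else [])
device = "cuda" if use_cuda else "cpu"

# Hypothetical stand-in: a QKV projection followed by attention
proj = torch.nn.Linear(64, 3 * 64).to(device)
x = torch.randn(8, 128, 64, device=device)


def step():
    # [batch, seq, 3 * 64] -> 3 x [batch, heads=4, seq, head_dim=16]
    qkv = proj(x).view(8, 128, 3, 4, 16).permute(2, 0, 3, 1, 4)
    with record_function("attention"):
        out = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2])
    out.sum().backward()


with profile(activities=activities) as prof:
    for _ in range(3):
        step()

sort_key = "cuda_time_total" if use_cuda else "cpu_time_total"
table = prof.key_averages().table(sort_by=sort_key, row_limit=10)
print(table)
```

Comparing the "attention" row against the total step time would show whether attention is a large enough share of the step for the SDPA speedup to matter.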
