scaled_dot_product_attention with a higher head count costs much more memory

Hi, I found that scaled_dot_product_attention costs much more memory when the head number is large (>= 16). Here is my code to reproduce the issue:


import torch

length = 10000
dim = 64         # total embedding dim, split evenly across heads
head_num1 = 8    # head_dim = 64 // 8  = 8
head_num2 = 16   # head_dim = 64 // 16 = 4
batch = 1

shapes = [[batch, head_num1, length, dim // head_num1],
          [batch, head_num2, length, dim // head_num2]]

for shape in shapes:
    torch.cuda.reset_peak_memory_stats()

    mask_shape = [1, 1, length, length]
    q = torch.rand(shape, dtype=torch.float16).cuda()
    k = torch.rand(shape, dtype=torch.float16).cuda()
    v = torch.rand(shape, dtype=torch.float16).cuda()

    attn_mask = torch.ones(mask_shape, dtype=torch.bool).cuda()
    x = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=0)

    peak_memory = torch.cuda.max_memory_allocated()
    print(f"head number {shape[1]} case peak memory: {peak_memory / 1e6:.2f} MB")

environment:

Python Version: 3.9.19
PyTorch Version: 2.1.0+cu121
CUDA Available: Yes
CUDA Version: 12.1
Current GPU: Tesla V100-SXM2-16GB
cuDNN: 8902

output:

head number 8 case peak memory: 405.17 MB
head number 16 case peak memory: 6716.14 MB

Just doubling the head number needs 16x more memory… is this normal?
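The tensors my script allocates directly don't come close to explaining that peak. A rough back-of-the-envelope check (my own accounting, using the same sizes as above):

# Rough sizing of what the script allocates directly (fp16 = 2 bytes, bool = 1 byte)
batch, heads, length, head_dim = 1, 16, 10000, 4

qkv_mb  = 3 * batch * heads * length * head_dim * 2 / 1e6  # q, k, v        ~3.84 MB
out_mb  = batch * heads * length * head_dim * 2 / 1e6      # output x       ~1.28 MB
mask_mb = length * length * 1 / 1e6                        # bool attn_mask ~100 MB
print(f"{qkv_mb + out_mb + mask_mb:.2f} MB allocated explicitly")

That is only about 105 MB in total, so the remaining ~6.6 GB must be allocated inside scaled_dot_product_attention itself.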

I tried the same script on another machine:

Python Version: 3.9.19
PyTorch Version: 2.1.0+cu121
CUDA Available: Yes
CUDA Version: 12.1
Current GPU: NVIDIA GeForce GTX 1070
cuDNN: 8902

output:

head number 8 case peak memory: 405.17 MB
head number 16 case peak memory: 406.49 MB

I think this second result is the correct behavior, but I failed to get it on many other machines. Hope someone can help, many thanks.

This is because dim = 64 and head_num2 = 16, so the per-head dimension is 64 // 16 = 4, which is not divisible by 8. PyTorch can't use its fused attention kernels in this case and, as far as I can tell, falls back to a much less memory-efficient path that materializes the full length × length attention matrix for every head (1 × 16 × 10000 × 10000 fp16 values is about 3.2 GB before any intermediates).
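One way to confirm this (a sketch, assuming PyTorch 2.x where torch.backends.cuda.sdp_kernel is available as a context manager): disable the math fallback for the problematic case. If only the unfused path can serve the call, it raises a RuntimeError instead of quietly allocating gigabytes, and it typically also prints warnings explaining why the fused kernels were rejected.

import torch

length, heads, head_dim = 10000, 16, 4  # the problematic case: head_dim = 64 // 16 = 4
q = torch.rand(1, heads, length, head_dim, dtype=torch.float16, device="cuda")
k = torch.rand(1, heads, length, head_dim, dtype=torch.float16, device="cuda")
v = torch.rand(1, heads, length, head_dim, dtype=torch.float16, device="cuda")
attn_mask = torch.ones(1, 1, length, length, dtype=torch.bool, device="cuda")

# Forbid the unfused math implementation; only flash / memory-efficient kernels may run.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
    try:
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0)
        print("a fused kernel handled this case")
    except RuntimeError as e:
        print("no fused kernel available:", e)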

To avoid this with 16 heads, you also need to set dim = 128, since 128 // 16 = 8.
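A minimal sketch of that suggestion, i.e. the same measurement as the original script but with dim = 128 so that the per-head dimension is 128 // 16 = 8:

import torch

length, batch = 10000, 1
dim, head_num = 128, 16                  # 128 // 16 = 8, so head_dim stays a multiple of 8
shape = [batch, head_num, length, dim // head_num]

torch.cuda.reset_peak_memory_stats()
q = torch.rand(shape, dtype=torch.float16).cuda()
k = torch.rand(shape, dtype=torch.float16).cuda()
v = torch.rand(shape, dtype=torch.float16).cuda()
attn_mask = torch.ones(1, 1, length, length, dtype=torch.bool).cuda()

x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0)
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e6:.2f} MB")

With the per-head dimension back to a multiple of 8, a fused kernel should be eligible again and the peak should land in the same range as the head_num = 8 run.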