Hi, I found that scaled_dot_product_attention costs much more memory when the head number is large (>= 16). Here is my code to reproduce the issue:
```python
import torch

length = 10000
dim = 64
head_num1 = 8
head_num2 = 16
batch = 1

# Same total embedding dim, split over 8 heads vs. 16 heads.
shapes = [[batch, head_num1, length, dim // head_num1],
          [batch, head_num2, length, dim // head_num2]]

for shape in shapes:
    torch.cuda.reset_peak_memory_stats()
    shape2 = [1, 1, length, length]  # broadcastable boolean attention mask
    q = torch.rand(shape, dtype=torch.float16).cuda()
    k = torch.rand(shape, dtype=torch.float16).cuda()
    v = torch.rand(shape, dtype=torch.float16).cuda()
    attn_mask = torch.ones(shape2, dtype=torch.bool).cuda()
    x = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=0)
    peak_memory = torch.cuda.max_memory_allocated()
    print(f"head number {shape[1]} case peak memory: {peak_memory / 1e6:.2f} MB")
```
environment:
Python Version: 3.9.19
PyTorch Version: 2.1.0+cu121
CUDA Available: Yes
CUDA Version: 12.1
Current GPU: Tesla V100-SXM2-16GB
cuDNN: 8902
output:
head number 8 case peak memory: 405.17 MB
head number 16 case peak memory: 6716.14 MB
Just doubling the head number needs 16x more memory… is this normal?
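Just as a rough back-of-the-envelope check (my assumption, not something I have confirmed: that a fallback path materializes the full [batch, heads, length, length] fp16 score matrix before the softmax):

```python
# Rough size of a fully materialized fp16 score matrix [batch, heads, length, length].
# This is only my assumption about a possible fallback path, not a confirmed explanation.
batch, length = 1, 10000
for heads in (8, 16):
    scores_gb = batch * heads * length * length * 2 / 1e9  # fp16 = 2 bytes per element
    print(f"{heads} heads: ~{scores_gb:.1f} GB for the score matrix alone")
# -> 8 heads: ~1.6 GB, 16 heads: ~3.2 GB
```

The 16-head figure (roughly doubled again if the softmax output is kept as well) is at least in the same ballpark as the 6716 MB I measured, while the 8-head run stays far below 1.6 GB, so the two head counts may be taking different code paths.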
I tried the same script on another machine:
Python Version: 3.9.19
PyTorch Version: 2.1.0+cu121
CUDA Available: Yes
CUDA Version: 12.1
Current GPU: NVIDIA GeForce GTX 1070
cuDNN: 8902
output:
head number 8 case peak memory: 405.17 MB
head number 16 case peak memory: 406.49 MB
I think this second result is the expected behavior, but I failed to reproduce it on many other machines. Hope someone can help, many thanks.