While implementing a batched matrix multiplication, I noticed that one batching strategy is far less memory-efficient than the other; see the code below:
import torch

# Input tensor: batch size = 8192, dim = 512
x = torch.FloatTensor(8192, 512).requires_grad_().cuda()

if True:
    # Batch strategy 1: split the 512 dims into 8 heads of 64
    x1 = x.view(8192, 8, 1, 64)              # 512 = 8 * 64
    W1 = torch.FloatTensor(8, 64, 64).cuda()
    out1 = torch.matmul(x1, W1)              # out: [8192, 8, 1, 64]
    print(torch.cuda.memory_allocated())     # 1107427328

if False:
    # Batch strategy 2: one dense weight matrix
    x2 = x.view(8192, 1, 512)                # add one dimension for batch matmul
    W2 = torch.FloatTensor(512, 512).cuda()  # larger than W1
    out2 = torch.matmul(x2, W2)              # out: [8192, 1, 512], same number of elements as out1
    print(torch.cuda.memory_allocated())     # 34603008
However, it turns out that Batch strategy 2 has a much lower memory cost, even though W2 is larger than W1. Everything else is the same: x1 and x2 have the same number of elements, and so do out1 and out2.
I also found that after removing requires_grad_(), the memory costs of the two strategies are similar.
What's the possible reason for that?
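For what it's worth, here is my own back-of-the-envelope check of the two reported numbers, assuming float32 (4 bytes per element). The numbers decompose exactly if one hypothesizes (I have not confirmed this in the source) that matmul materializes a broadcast copy of W1 with shape [8192, 8, 64, 64] and keeps it alive for the backward pass:

```python
# Hypothesis (unconfirmed): in strategy 1, matmul broadcasts W1 from
# [8, 64, 64] to [8192, 8, 64, 64] and materializes that copy, which
# autograd then keeps for the backward pass. In strategy 2 no broadcast
# of W2 is needed, so nothing extra is saved.
elem = 4  # bytes per float32 element

x_bytes     = 8192 * 512 * elem           # input x
out_bytes   = 8192 * 8 * 1 * 64 * elem    # out1 (same element count as out2)
W1_bytes    = 8 * 64 * 64 * elem          # small per-head weight
W2_bytes    = 512 * 512 * elem            # dense weight
W1_expanded = 8192 * 8 * 64 * 64 * elem   # hypothesized broadcast copy of W1

strategy1 = x_bytes + out_bytes + W1_bytes + W1_expanded
strategy2 = x_bytes + out_bytes + W2_bytes

print(strategy1)  # 1107427328 -- matches the reported number exactly
print(strategy2)  # 34603008   -- matches the reported number exactly
```

This would also be consistent with the requires_grad_() observation: without grad tracking, the broadcast copy would not need to be saved for backward.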