While implementing batched matrix multiplication, I noticed that it can be surprisingly memory-inefficient; see the code below.
# Input tensor
## Batch size=8192, dim=512
x = torch.FloatTensor(8192, 512).requires_grad_().cuda()
# Batch strategy 1
x1 = x.view(8192, 8, 1, 64) # 512 = 8 * 64
W1 = torch.FloatTensor(8, 64, 64).cuda()
out1 = torch.matmul(x1, W1) # out: [8192, 8, 1, 64]
print(torch.cuda.memory_allocated()) # 1107427328
# Batch strategy 2
x2 = x.view(8192, 1, 512) # add one dimension for batch matmul
W2 = torch.FloatTensor(512, 512).cuda() # larger than W1
# out: [8192, 1, 512] # the same number of elements as out1
out2 = torch.matmul(x2, W2)
print(torch.cuda.memory_allocated()) # 34603008
However, it turns out that Batch strategy 2 has a lower memory cost, despite W2 being larger than W1. Everything else is the same: x1 and x2 have the same number of elements, as do out1 and out2.
I also found that removing the requires_grad_() makes the memory costs similar.
What’s the possible reason for that?
If you look at what matmul does (it's in C++, but you can translate it fairly directly into Python), you'll see that a number of reshaping / broadcasting ops are involved. As matmul does not have a custom derivative (you can check this in tools/autograd/derivatives.yaml), the backward is done by recording the operations performed and the inputs they require. Quite likely, some of the intermediate results are saved for the backward pass.
If it’s excessive, you could file a bug to implement an explicit backward for matmul.
An alternative could be to try einsum and see if it is better about it (but I’m not sure it is).
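For what it's worth, the reported allocations line up exactly with this explanation if you assume autograd saves a broadcast-expanded copy of W1: matmul broadcasts W1's batch dimension from [8, 64, 64] to [8192, 8, 64, 64] to match x1, and that expanded float32 tensor alone is 1 GiB. A back-of-envelope check (plain arithmetic over the shapes in the post, not a measurement):

```python
# All tensors are float32, i.e. 4 bytes per element.

# Strategy 1: x + W1 + out1 + the broadcast-expanded W1 saved for backward
x_bytes        = 8192 * 512 * 4            # 16_777_216
w1_bytes       = 8 * 64 * 64 * 4           # 131_072
out1_bytes     = 8192 * 8 * 1 * 64 * 4     # 16_777_216
expanded_w1    = 8192 * 8 * 64 * 64 * 4    # 1_073_741_824 (W1 expanded to x1's batch dim)
print(x_bytes + w1_bytes + out1_bytes + expanded_w1)  # 1107427328 -- matches the post

# Strategy 2: just x + W2 + out2; broadcasting never expands W2
w2_bytes   = 512 * 512 * 4                 # 1_048_576
out2_bytes = 8192 * 1 * 512 * 4            # 16_777_216
print(x_bytes + w2_bytes + out2_bytes)     # 34603008 -- matches the post
```

So the extra ~1 GiB in strategy 1 is fully accounted for by a single saved [8192, 8, 64, 64] intermediate.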
Hi Tomas, thanks very much for your suggestions!
I can imagine that matmul involves many extra operations for this special case. After switching to einsum, the memory cost is now comparable to Batch strategy 2. It is really surprising!
For people who encounter a similar problem, here is what I did:
# Batch strategy 1' ===> optimizes Batch strategy 1
x1 = x.view(8192, 8, 64) # 512 = 8 * 64
W1 = torch.FloatTensor(8, 64, 64).normal_().cuda()
out1 = torch.einsum('bij,abj->abi', (W1, x1))
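A note for anyone copying this: the subscripts 'bij,abj->abi' contract x1 against W1[b, i, j], i.e. against W1 transposed in its last two dimensions, whereas the matmul in Batch strategy 1 contracts against W1[b, j, i]. For a freshly initialized weight the difference doesn't matter, but for bit-for-bit the same computation you would write 'bji,abj->abi'. A small NumPy sketch (sizes shrunk from the post's 8192/512 for a quick CPU check):

```python
import numpy as np

# Shrunk batch size (a=4 instead of 8192); same head layout b=8, i=j=64
a, b, j, i = 4, 8, 64, 64
rng = np.random.default_rng(0)
x = rng.standard_normal((a, b, j)).astype(np.float32)
W = rng.standard_normal((b, j, i)).astype(np.float32)

# Batch strategy 1: view + broadcast matmul, out[a, b, i] = sum_j x[a, b, j] * W[b, j, i]
out_matmul = np.matmul(x.reshape(a, b, 1, j), W).reshape(a, b, i)

# Matching einsum: 'bji' indexes W as W[b, j, i]; the post's 'bij' would use
# W transposed in its last two dims instead
out_einsum = np.einsum('bji,abj->abi', W, x)

assert np.allclose(out_matmul, out_einsum, atol=1e-4)
```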