Memory inefficiency in batched matrix multiplication (with autograd)

While implementing batched matrix multiplication, I noticed that it is not memory-efficient. See the code below:

import torch

# Input tensor: batch size = 8192, dim = 512
x = torch.randn(8192, 512, device="cuda", requires_grad=True)

# Run one strategy at a time: memory_allocated() is cumulative.
if True:
    # Batch strategy 1: split 512 = 8 * 64 and use one 64x64 weight per chunk
    x1 = x.view(8192, 8, 1, 64)
    W1 = torch.randn(8, 64, 64, device="cuda")
    out1 = torch.matmul(x1, W1)  # out1: [8192, 8, 1, 64]

    print(torch.cuda.memory_allocated())  # 1107427328

if False:
    # Batch strategy 2: a single 512x512 weight (larger than W1)
    x2 = x.view(8192, 1, 512)  # add a singleton dimension for batched matmul
    W2 = torch.randn(512, 512, device="cuda")
    # out2: [8192, 1, 512], the same number of elements as out1
    out2 = torch.matmul(x2, W2)
    print(torch.cuda.memory_allocated())  # 34603008

However, it turns out that Batch strategy 2 uses less memory, even though W2 is larger than W1. Everything else is the same (x1 and x2 have the same number of elements, as do out1 and out2).

I also found that without requires_grad_(), the memory costs of the two strategies are similar.

What could be the reason for this?
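Counting bytes reproduces both printed values exactly, under two assumptions: every tensor is float32 (4 bytes per element), and Strategy 1 additionally holds a contiguous copy of W1 broadcast to [8192, 8, 64, 64]. The match supports that hypothesis but does not by itself prove it:

```python
# Back-of-the-envelope check of the two memory_allocated() values.
# Assumption: every tensor is float32 (4 bytes per element).
BYTES = 4

x_bytes = 8192 * 512 * BYTES    # x (x1 and x2 are views of it, no extra memory)
out_bytes = 8192 * 512 * BYTES  # out1 and out2 both have 8192 * 512 elements
w1_bytes = 8 * 64 * 64 * BYTES
w2_bytes = 512 * 512 * BYTES

# Hypothesis: matmul broadcasts W1 over the 8192 batch entries, i.e. to
# [8192, 8, 64, 64], materializes that as a contiguous copy, and autograd
# keeps the copy alive for the backward pass because x requires grad.
w1_broadcast_bytes = 8192 * 8 * 64 * 64 * BYTES

strategy1 = x_bytes + out_bytes + w1_bytes + w1_broadcast_bytes
strategy2 = x_bytes + out_bytes + w2_bytes

print(strategy1)                   # 1107427328
print(strategy2)                   # 34603008
print(w1_broadcast_bytes / 2**30)  # 1.0 (GiB)
```

The hypothetical broadcast copy accounts for exactly 1 GiB. It would also explain the requires_grad_() observation: without a backward pass to prepare for, nothing keeps that temporary alive after the multiplication, so memory_allocated() drops to about 33.7 MB for Strategy 1 (peak usage, e.g. torch.cuda.max_memory_allocated(), would still include it).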


If you look at what matmul does (it’s in C++, but you can transpose it directly into Python), you see that there are a number of reshaping/broadcasting ops involved. As matmul does not have a custom derivative (you can see this in tools/autograd/derivatives.yaml), the backward is done by keeping track of the operations performed and the inputs required for the backward. Quite likely, some of the intermediate results are cached for the backward.
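Transposed into Python, the branch that Batch strategy 1 hits (both operands at least 3-D, batch dimensions broadcast) looks roughly like the sketch below. `matmul_broadcast_sketch` is a hypothetical, simplified re-implementation that assumes matmul expands both operands to the common batch shape and then reshapes them for `torch.bmm`; the real ATen code has more branches and may differ between versions. For a 3-D @ 2-D product, as in Batch strategy 2, matmul can instead fold the batch into one 2-D matrix multiplication, so W2 is never broadcast.

```python
import torch


def matmul_broadcast_sketch(a: torch.Tensor, b: torch.Tensor):
    """Hypothetical sketch of torch.matmul's broadcasting path for
    a.dim() >= 3 and b.dim() >= 3. Returns (result, flattened copy of b)."""
    n, m = a.shape[-2:]
    m2, p = b.shape[-2:]
    if m != m2:
        raise ValueError(f"inner dimensions differ: {m} vs {m2}")
    # Broadcast the batch dimensions of both operands against each other.
    batch = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    # expand() only creates views (stride 0 along broadcast dimensions) ...
    a_exp = a.expand(*batch, n, m)
    b_exp = b.expand(*batch, m, p)
    # ... but flattening the batch dims of a stride-0 view for bmm cannot be
    # expressed as another view, so reshape() materializes a full copy.
    a_flat = a_exp.reshape(-1, n, m)
    b_flat = b_exp.reshape(-1, m, p)
    out = torch.bmm(a_flat, b_flat)
    return out.view(*batch, n, p), b_flat


# Small CPU version of Batch strategy 1: x1 is [16, 8, 1, 64], W1 is [8, 64, 64].
torch.manual_seed(0)
x1 = torch.randn(16, 8, 1, 64)
W1 = torch.randn(8, 64, 64)

out, W1_flat = matmul_broadcast_sketch(x1, W1)
print(torch.allclose(out, torch.matmul(x1, W1), atol=1e-5))  # True
print(W1_flat.shape)  # torch.Size([128, 64, 64]): 16x the elements of W1
print(W1_flat.data_ptr() != W1.data_ptr())  # True: new buffer, not a view
```

In the full-size case the flattened copy holds one 64x64 matrix per (batch, chunk) pair. Because x requires grad, `bmm` must keep that copy around, since the gradient with respect to x is computed from it.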
If it’s excessive, you could file a bug to implement an explicit backward for matmul.

An alternative could be to try einsum and see whether it is better in this respect (but I’m not sure it is).
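One way to check this without a GPU is to count the bytes autograd saves for backward, using the `torch.autograd.graph.saved_tensors_hooks` context manager (available since PyTorch 1.10). Below is a sketch with scaled-down shapes (batch 64 instead of 8192). The helper functions are made up for this example, and `"abj,bji->abi"` is one einsum spelling of Batch strategy 1's product:

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def saved_bytes(fn, *inputs):
    """Run fn(*inputs) once and return how many bytes autograd saves for backward."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # keep the tensor unchanged; we only count it

    def unpack(t):
        return t

    with saved_tensors_hooks(pack, unpack):
        fn(*inputs)
    return total


# Hypothetical helpers: Batch strategy 1 written with matmul and with einsum.
def strategy1_matmul(x, W1):
    return torch.matmul(x.view(x.shape[0], 8, 1, 64), W1)


def strategy1_einsum(x, W1):
    return torch.einsum("abj,bji->abi", x.view(x.shape[0], 8, 64), W1)


torch.manual_seed(0)
x = torch.randn(64, 512, requires_grad=True)  # batch 64 instead of 8192
W1 = torch.randn(8, 64, 64)                   # W1 does not require grad, as above

matmul_saved = saved_bytes(strategy1_matmul, x, W1)
einsum_saved = saved_bytes(strategy1_einsum, x, W1)
print(matmul_saved, einsum_saved)
```

The exact numbers depend on the PyTorch version. Assuming matmul still takes the expand-and-copy path for these shapes, the matmul variant should save at least the 8 MiB broadcast copy of W1 (64 * 8 * 64 * 64 * 4 bytes), while the einsum variant should only save tensors on the order of the size of x and W1, since einsum can treat the chunk index as the shared bmm batch dimension instead of broadcasting W1.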

Best regards

Thomas


Hi Thomas, thanks very much for your suggestions!

I can imagine that matmul involves many operations in this special case. After switching to einsum, the memory cost is now comparable to Batch strategy 2, which is really surprising!

For people who encounter a similar problem, here is what I did:

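# Batch strategy 1' (Batch strategy 1 with einsum instead of matmul)
x1 = x.view(8192, 8, 64)  # 512 = 8 * 64
W1 = torch.randn(8, 64, 64, device="cuda")
# out1[a, b, i] = sum_j x1[a, b, j] * W1[b, j, i], the same product as Batch strategy 1
out1 = torch.einsum("abj,bji->abi", x1, W1)  # out1: [8192, 8, 64]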
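Two einsum spellings are easy to confuse here. `"abj,bji->abi"` with operands `(x1, W1)` contracts x1 with `W1[b, j, i]` and reproduces Batch strategy 1 exactly, whereas `"bij,abj->abi"` with operands `(W1, x1)` contracts with `W1[b, i, j]`, i.e. it multiplies by each W1[b] transposed. For a freshly initialized learnable weight that makes no practical difference, but the outputs are not identical. A small CPU check with scaled-down sizes:

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 512)
W1 = torch.randn(8, 64, 64)
x1 = x.view(32, 8, 64)

# Batch strategy 1 with matmul, singleton row dimension removed: [32, 8, 64]
ref = torch.matmul(x.view(32, 8, 1, 64), W1).squeeze(2)
ref_transposed = torch.matmul(x.view(32, 8, 1, 64), W1.transpose(1, 2)).squeeze(2)

same_as_matmul = torch.einsum("abj,bji->abi", x1, W1)  # uses W1[b, j, i]
transposed_w = torch.einsum("bij,abj->abi", W1, x1)    # uses W1[b, i, j]

print(torch.allclose(same_as_matmul, ref, atol=1e-4))            # True
print(torch.allclose(transposed_w, ref_transposed, atol=1e-4))   # True
print(torch.allclose(transposed_w, ref, atol=1e-4))              # False for a generic W1
```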