I would like to do an in-place multiplication of tensors as follows: Q is a tensor of size b x n x n (we have b batches of different Q's), and T is a tensor of size n x n. For each batch, I would like to perform T * Q. My current method is to first do T = T.repeat(b, 1, 1) and then compute T * Q. This seems wasteful, as the copies of T take up valuable space in GPU memory. Is there a better way to do this?
You might want to avoid explicitly calling repeat on T and instead let PyTorch broadcast T over the batch dimension, which lowers the peak memory usage, as seen here:
import torch
print(torch.cuda.memory_allocated()/1024**2) # 0
print(torch.cuda.max_memory_allocated()/1024**2) # 0
b, n = 1024, 1024
Q = torch.randn(b, n, n, device='cuda') # 4096 MB
T = torch.randn(1, n, n, device='cuda') # 4 MB
print(torch.cuda.memory_allocated()/1024**2) # 4100 MB
print(torch.cuda.max_memory_allocated()/1024**2) # 4100 MB
#T = T.repeat(b, 1, 1) # 4096 MB
#print(torch.cuda.memory_allocated()/1024**2) # 8192 MB
#print(torch.cuda.max_memory_allocated()/1024**2) # 8196 MB
y = T*Q # 4096 MB
print(torch.cuda.memory_allocated()/1024**2) # 8196 MB / 12288 MB with repeat
print(torch.cuda.max_memory_allocated()/1024**2) # 8196 MB / 12288 MB with repeat
You can reuse this code to check the memory usage for your use case and compare both approaches.
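As a complementary check, here is a small sketch (with small, arbitrary sizes, run on CPU so no GPU is needed) confirming that broadcasting produces the same result as the explicit repeat. It also shows that mul_ broadcasts as well, so the multiplication can be done truly in place on Q, avoiding both the repeated copies of T and the allocation of a separate output tensor:

```python
import torch

b, n = 4, 8  # small arbitrary sizes so the check runs quickly on CPU
Q = torch.randn(b, n, n)
T = torch.randn(n, n)

# Broadcasting: T of shape (n, n) is treated as (1, n, n) and virtually
# expanded to (b, n, n) without copying its data.
y_broadcast = T * Q
y_repeat = T.repeat(b, 1, 1) * Q
assert torch.allclose(y_broadcast, y_repeat)

# In-place variant: mul_ also broadcasts, so Q is overwritten directly and
# no new b x n x n tensor is allocated at all.
Q_inplace = Q.clone()  # clone only so we can compare against y_broadcast
Q_inplace.mul_(T)
assert torch.allclose(Q_inplace, y_broadcast)
print("broadcasted, repeated and in-place results match")
```

Since your original goal was an in-place multiplication, Q.mul_(T) is the most memory-frugal option of the three, at the cost of destroying the original contents of Q.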