Hi all, I just want to check that I understand the memory consumption of some code that involves broadcasting; it increases by an order of magnitude when I request gradients.
Using a GPU:
device = torch.device('cuda:0')
I build two “batched” tensors and a fixed tensor to apply to them:
d0, d1 = 100, 1000
batch_size = 800
x_ng = torch.randn(batch_size, d0, 1, device=device, requires_grad=False)
x_g = torch.randn(batch_size, d0, 1, device=device, requires_grad=True)
U = torch.randn(1, d1, d0, device=device)
Then I check GPU memory allocation after each matmul; it spikes when broadcasting is required and an input requires gradients:
check_mem = lambda s : print(f"{s}: {torch.cuda.memory_allocated()/1e6:.3f} MB")
check_mem("Before computation")
#requires broadcasting U, no grads
y_ng = torch.matmul(U, x_ng)
check_mem("After no-grad matmul")
del y_ng
#requires broadcasting U, grads
y_g = torch.matmul(U, x_g)
check_mem("After grad matmul")
del y_g
# no broadcast, grads
y_g = torch.matmul(U.squeeze(), x_g)
check_mem("After grad matmul (no broadcast)")
del y_g
with the result:
Before computation: 8.923 MB
After no-grad matmul: 12.123 MB
After grad matmul: 332.123 MB
After grad matmul (no broadcast): 12.123 MB
Here’s what I think is going on: when gradients are required, PyTorch needs to save the tensor multiplying x for the backward pass, and the saved version is taken after the broadcast, i.e. the fully expanded copy of U, hence the large allocation. Is this accurate? Note that the size of the output tensor is the same in each case.
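For what it's worth, a back-of-the-envelope check supports this (assuming float32, 4 bytes per element): an expanded copy of U with shape (batch_size, d1, d0) accounts exactly for the jump.

```python
# If matmul saves an expanded copy of U of shape (batch_size, d1, d0)
# for the backward pass, its size in MB would be:
batch_size, d0, d1 = 800, 100, 1000
bytes_per_float32 = 4
expanded_U_mb = batch_size * d1 * d0 * bytes_per_float32 / 1e6
print(expanded_U_mb)  # 320.0 -- matches the 332.123 - 12.123 = 320 MB jump
```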
I don’t see anything related in the broadcasting docs.
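In case it helps anyone reproducing this, here is a small CPU-sized sketch of the two paths from the snippet above (the tiny batch size is just to keep the check quick): passing the 2-D `U.squeeze(0)` lets matmul handle the batching without the leading broadcast dim, which is the variant that showed no memory spike, and both paths give the same result.

```python
import torch

d0, d1, batch_size = 100, 1000, 8  # small sizes for a quick CPU check
U = torch.randn(1, d1, d0)
x = torch.randn(batch_size, d0, 1, requires_grad=True)

# Broadcasting path: U's leading dim of 1 is broadcast to batch_size.
y_broadcast = torch.matmul(U, x)

# Broadcast-free path: a plain 2-D matrix applied to a batch of vectors;
# this is the variant that showed no memory spike above.
y_squeezed = torch.matmul(U.squeeze(0), x)

# Both paths produce the same (batch_size, d1, 1) output.
assert y_broadcast.shape == (batch_size, d1, 1)
assert torch.allclose(y_broadcast, y_squeezed, rtol=1e-4, atol=1e-5)
```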