Understanding broadcasting with gradients

Hi all, I just want to check that I understand the memory consumption of some code that involves broadcasting: it increases by more than an order of magnitude when I request gradients.

Using a GPU,

import torch

device = torch.device('cuda:0')

I build two “batched” tensors and a fixed tensor to apply to them:

d0, d1 = 100, 1000
batch_size = 800

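# Create the tensors directly on the device so that x_g stays a leaf tensor
# (calling .to(device) on a tensor with requires_grad=True returns a non-leaf copy).
x_ng = torch.randn(batch_size, d0, 1, requires_grad=False, device=device)
x_g = torch.randn(batch_size, d0, 1, requires_grad=True, device=device)
U = torch.randn(1, d1, d0, device=device)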

Then I check the GPU memory allocated after performing a matmul. It spikes when a broadcast is required and one of the inputs requires gradients:

def check_mem(s):
    # Report the memory currently allocated by tensors on the default CUDA device, in MB.
    print(f"{s}: {torch.cuda.memory_allocated() / 1e6:.3f} MB")

check_mem("Before computation")

# requires broadcasting U, no grads
y_ng = torch.matmul(U, x_ng)
check_mem("After no-grad matmul")
del y_ng

# requires broadcasting U, grads
y_g = torch.matmul(U, x_g)
check_mem("After grad matmul")
del y_g

# no broadcast, grads
y_g = torch.matmul(U.squeeze(), x_g)
check_mem("After grad matmul (no broadcast)")
del y_g

This gives the following output:

Before computation: 8.923 MB
After no-grad matmul: 12.123 MB
After grad matmul: 332.123 MB
After grad matmul (no broadcast): 12.123 MB

Here’s what I think is going on: when gradients are required, PyTorch has to save the tensor that multiplies x for the backward pass, and the saved copy is taken after U has been broadcast across the batch dimension, hence the large size. Is this accurate? Note that the output tensor has the same size in every case.
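
As a rough sanity check on that guess, the expected sizes can be worked out by hand. The sketch below assumes float32 elements (4 bytes, the default dtype of torch.randn) and uses the same MB = 1e6 bytes convention as check_mem; the helper name mb is only for illustration.

```python
# Sizes taken from the post above; float32 is assumed (4 bytes per element).
d0, d1 = 100, 1000
batch_size = 800
bytes_per_float = 4


def mb(n_elements):
    """Size in MB (1e6 bytes, as in check_mem) of a float32 tensor."""
    return n_elements * bytes_per_float / 1e6


output_mb = mb(batch_size * d1 * 1)        # y has shape (800, 1000, 1)
expanded_u_mb = mb(batch_size * d1 * d0)   # U broadcast to (800, 1000, 100)
original_u_mb = mb(1 * d1 * d0)            # U as stored, shape (1, 1000, 100)

print(f"output y:   {output_mb:.1f} MB")
print(f"expanded U: {expanded_u_mb:.1f} MB")
print(f"original U: {original_u_mb:.1f} MB")
```

The 3.2 MB output matches the step from 8.923 MB to 12.123 MB, and a batch-sized copy of U at 320 MB matches the step from 12.123 MB to 332.123 MB. That is consistent with a materialized broadcast of U being kept alive for the backward pass, although the arithmetic alone does not confirm it.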

I don’t see anything related in the broadcasting docs.