Flat model grads, CUDA graph capture, and grad-copy-elision

This question is a bundle of a few things I’ve struggled to answer from the docs.

At the top level, I am interested in approaches to maintaining “flat grads” – i.e., the model gradients live in a single contiguous buffer (of the same dtype as the model params), and the .grad attribute of each param is a view into that contiguous buffer. Afaict, this is still possible (if not exactly encouraged) by manually mutating .grad to be a view tensor like so:

import torch

hidden = 128  # example size
m = torch.nn.Linear(hidden, hidden).cuda()
buf = torch.zeros(m.weight.numel(), dtype=m.weight.dtype, device=m.weight.device)
m.weight.grad = buf[0:m.weight.numel()].view_as(m.weight)
# And so on for all the parameters (see the helper sketch below).
# After running fwd + bwd, the .grad still aliases buf:
m(torch.randn(hidden, hidden, device="cuda")).sum().backward()
assert buf.untyped_storage().data_ptr() == m.weight.grad.untyped_storage().data_ptr()
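
For concreteness, the general version of that pattern I have in mind looks roughly like this (just a sketch, not an official API, and it assumes all parameters share a single dtype and device):

import torch

def flatten_grads(module: torch.nn.Module) -> torch.Tensor:
    # Allocate one contiguous buffer and point every param's .grad into it.
    params = [p for p in module.parameters() if p.requires_grad]
    flat = torch.zeros(sum(p.numel() for p in params),
                       dtype=params[0].dtype, device=params[0].device)
    offset = 0
    for p in params:
        p.grad = flat[offset:offset + p.numel()].view_as(p)
        offset += p.numel()
    return flat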

Two partially related questions about this technique:

  1. How does this interact with CUDA graph capture – and more generally, is it possible to reason at all about the pointers in a captured CUDA graph? By definition, the output pointers of the gradient-computing kernels need to be stable across graph executions – so if I do something like the above, will capture record the pointer into buf from the first run and reuse it in subsequent runs? Or does graph capture need to allocate the .grad attributes from its own internal memory pool?

  2. There is (afaik) a small optimization when .grad is None (e.g. after zero_grad(set_to_none=True)), where the computed gradient doesn’t need to be accumulated during backprop but can simply be stored (saving an extra memory round trip) – I’ve heard this referred to as “grad copy elision”. Is there any way to convince autograd to perform this same optimization in the flat-grads case above, where I want the grad pointer to remain unchanged but don’t care about accumulating into it? (The snippet right below this list illustrates the pointer behaviour I mean.)
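
To make the trade-off in question 2 concrete, this is the behaviour I am trying to reconcile (again just a sketch; hidden is an example size, and the second assert reflects my understanding of what happens after set_to_none=True):

import torch

hidden = 128  # example size
m = torch.nn.Linear(hidden, hidden).cuda()
buf = torch.zeros(m.weight.numel(), dtype=m.weight.dtype, device=m.weight.device)
m.weight.grad = buf.view_as(m.weight)
x = torch.randn(hidden, hidden, device="cuda")

# With .grad pre-assigned, backward accumulates into the existing view,
# so the pointer stays stable -- but we pay the extra read of the old grad.
m(x).sum().backward()
assert m.weight.grad.untyped_storage().data_ptr() == buf.untyped_storage().data_ptr()

# With set_to_none=True, the next backward creates a fresh .grad (this is
# where the copy elision can kick in), but the aliasing into buf is lost.
m.zero_grad(set_to_none=True)
m(x).sum().backward()
assert m.weight.grad.untyped_storage().data_ptr() != buf.untyped_storage().data_ptr()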

Thanks!

In the whole-network capture example (from the CUDA Graphs section of the docs) the gradients are explicitly set to None via zero_grad(set_to_none=True), with this comment:

Sets grads to None before capture, so backward() will create .grad attributes with allocations from the graph’s private pool

which points towards your last statement (capture allocating the .grad attributes from its own memory pool). A bit further down you then see the g.replay() call, which might look a bit odd since the zero_grad call wasn’t captured. However, the comment again mentions:

replay() includes forward, backward, and step.
You don’t even need to call optimizer.zero_grad() between iterations
because the captured backward refills static .grad tensors in place.

Based on this I would assume that deleting the gradients after the capture could break the replay. However, it would still be interesting to see whether you can “bake in” the flattened gradient buffers by keeping them assigned during the capture (instead of setting them to None), or whether this would break directly and raise errors.
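
If someone wants to try it, the experiment could look roughly like the sketch below (untested, so treat it as the thing to run rather than a statement of what happens; the warmup on a side stream follows the docs example, and names like hidden, model, flat are just illustrative):

import torch

hidden = 128  # example size
model = torch.nn.Linear(hidden, hidden).cuda()
static_input = torch.randn(hidden, hidden, device="cuda")

# Flat buffer + per-parameter grad views, assigned before warmup/capture.
flat = torch.zeros(sum(p.numel() for p in model.parameters()),
                   dtype=model.weight.dtype, device="cuda")
offset = 0
for p in model.parameters():
    p.grad = flat[offset:offset + p.numel()].view_as(p)
    offset += p.numel()

# Warmup on a side stream, as in the docs example.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input).sum().backward()
torch.cuda.current_stream().wait_stream(s)

# Capture one fwd+bwd; the grads are deliberately *not* set to None here.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    model(static_input).sum().backward()

# If the captured backward keeps accumulating into the pre-existing views,
# the grads should still alias `flat` after a replay; if capture instead had
# to allocate .grad from its private pool (or errored out above), this is
# where it would show.
g.replay()
for p in model.parameters():
    assert p.grad.untyped_storage().data_ptr() == flat.untyped_storage().data_ptr()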