Error when backpropagating through a compiled PyTorch module multiple times

Hi!

This is actually a semi-fundamental issue with torch.compile. In particular:

(1) in eager mode: this works, because the backward graphs for q and k are completely independent. You can therefore call .backward() on each of them separately.

(2) with torch.compile: by compiling your entire AxialRoPE module into a single compiled region, torch.compile is free to try to fuse the backward compute of q and k together. This causes the "compiled" version of the backward graph to show up as one giant node in the autograd graph, instead of a bunch of tiny autograd nodes that can be run independently.

So the tldr here is that the benefits of compilation / fusion mean that we lose some of the fine-grained detail of the autograd graph.
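For concreteness, here is a minimal sketch of the failure mode. The TwoOutputs module, shapes, and names below are made up for illustration (they are not your AxialRoPE code), and the exact error text may differ across versions:

import torch

# Hypothetical toy module standing in for your compiled region: one compiled
# forward that produces both q and k (names/shapes are illustrative only).
class TwoOutputs(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_q = torch.nn.Linear(8, 8)
        self.proj_k = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj_q(x), self.proj_k(x)

mod = torch.compile(TwoOutputs())
q, k = mod(torch.randn(4, 8))

q.sum().backward()
# k.sum().backward()  # expected to raise something like
# "Trying to backward through the graph a second time", because both outputs
# come from a single compiled autograd node whose saved buffers were freed
# by the first backward call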

You effectively have three options here:

Option 1: Run with

(q + k).sum().backward()

to perform all backward compute in a single call. This will allow us to compute gradients for q and k simultaneously (which is especially good for perf if torch.compile has fused their backward compute together).

Option 2: Run with:

q.sum().backward(retain_graph=True)
k.sum().backward()

This will allow autograd to keep the compiled backward graph around after the first backward call (at the cost of extra memory). In particular, the compiled backward code is allowed to fuse the backward compute for q and k together, so you might end up doing some redundant computation if you call .backward() twice.

Option 3:

Explicitly separate your logic for q's compute and k's compute into separate regions of code, so you can compile them separately. This will effectively force torch.compile not to bundle any of the compute for q's backward and k's backward together, so you can call .backward() separately for each output.
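A minimal sketch of what that could look like, again with made-up standalone projections rather than your actual AxialRoPE logic:

import torch

# Hypothetical q/k paths compiled as two separate regions, so each output
# gets its own independent compiled backward node.
proj_q = torch.nn.Linear(8, 8)
proj_k = torch.nn.Linear(8, 8)

compute_q = torch.compile(proj_q)
compute_k = torch.compile(proj_k)

x = torch.randn(4, 8)
q = compute_q(x)
k = compute_k(x)

q.sum().backward()  # only touches q's compiled backward graph
k.sum().backward()  # independent graph, so no retain_graph needed

The trade-off is that you give up any cross-fusion between the two paths, so this is usually only worth it if you really need the separate .backward() calls.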
