Torch.compile emits double backward error even when the corresponding function is disabled

Hi,

I’m trying to compile my model and use its output to calculate divergence.

I’m aware that double backward is not compatible with torch.compile. Thus, I separated the forward and the divergence calculation and even put torch.compiler.disable on the latter. No luck so far, though.

Could someone please explain why torch.compile is trying to compile the div function below even when it’s supposed to be disabled?

import torch


class Foo:
    def __init__(self):
        self.model = torch.nn.Linear(5, 5)
        self.data = [torch.rand(5, requires_grad=True) for _ in range(5)]

    @torch.compile
    def foor(self, i):
        output = self.model(self.data[i])
        return output


@torch.compiler.disable
def div(foo, output):
    (dx,) = torch.autograd.grad(
        output,
        foo.data[1],
        grad_outputs=torch.ones_like(output),
        create_graph=True,
    )
    (dx2,) = torch.autograd.grad(
        dx,
        foo.data[1],
        grad_outputs=torch.ones_like(dx),
        create_graph=True,
    )
    return dx2


if __name__ == "__main__":
    foo = Foo()
    # warmup
    o1 = foo.foor(0)

    output = foo.foor(1)

    dx2 = div(foo, output)

Error message:

Traceback (most recent call last):
  File "/main/torch_autograd/test.py", line 55, in <module>
    dx2 = div(foo, output)
          ^^^^^^^^^^^^^^^^
  File "/nnvenv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/main/torch_autograd/test.py", line 39, in div
    (dx2,) = torch.autograd.grad(
             ^^^^^^^^^^^^^^^^^^^^
  File "/nnvenv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 496, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/nnvenv/lib/python3.11/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nnvenv/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/nnvenv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1728, in backward
    raise RuntimeError(
RuntimeError: torch.compile with aot_autograd does not currently support double backward

If I understand correctly, you’d like to compile the forward and then run the backward in eager mode. This is tricky because running the backward eagerly requires the autograd graph to be built, and that no longer happens in the way you want: compile collapses the entire forward computation into a single mega node.

Hi @soulitzer,

Thanks for your reply!

Please correct me if I’m wrong, but am I not only compiling the forward in the above code?

The foor function, which contains the model’s forward, is torch.compiled. After running that function, I pass its output to the div function, which contains the double backward.

I thought those two operations (foor and div) are separated, making only the former JIT-compiled.

Am I wrong?

Oh hm…
I guess the above error isn’t caused by div getting torch.compiled (I had thought torch.compile was somehow still trying to compile the div function). It’s because the gradient calculation dx2 depends on dx, which in turn depends on the output of the torch.compiled foo.foor(), whose gradient is computed AOT? And I’m guessing calculating the latter part is somehow problematic?

Could you please explain why this is not torch.compile-able, @ptrblck? How come this operation (double backward) is something that torch.compile doesn’t natively support?

Sorry in advance for my lack of understanding in how JIT actually works in torch.compile. I’m trying to learn.

Your torch.compile decorator is indeed only wrapping your forward code. But if your inputs require grad and grad mode is enabled, torch.compile traces out the backward graph ahead of time, regardless of whether the code where you later invoke backward is itself compiled or disabled.

It’s not possible today to run the forward compiled but the backward non-compiled (in eager). To run the backward in eager, the autograd graph needs to be built every iteration, and that won’t happen if you run the compiled forward graph.

There are also advantages to considering the two graphs jointly, e.g. to optimize what is saved between them.
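For contrast, here is a minimal sketch of the same double backward with the forward left in eager mode, so that the per-op autograd graph exists. The names model, x, dx, and dx2 are illustrative. The Tanh is my addition and is not in the original code: with a plain Linear layer the first derivative with respect to the input does not depend on the input, so the second torch.autograd.grad call would have nothing to differentiate.

```python
import torch

torch.manual_seed(0)

# Assumption: a Tanh is appended so the first derivative depends on x.
model = torch.nn.Sequential(torch.nn.Linear(5, 5), torch.nn.Tanh())
x = torch.rand(5, requires_grad=True)

# Eager forward: autograd records one node per op.
output = model(x)

# First backward; create_graph=True keeps dx differentiable.
(dx,) = torch.autograd.grad(
    output,
    x,
    grad_outputs=torch.ones_like(output),
    create_graph=True,
)

# Second backward through the graph built by the first one.
(dx2,) = torch.autograd.grad(
    dx,
    x,
    grad_outputs=torch.ones_like(dx),
    create_graph=True,
)
print(dx2.shape)
```

The trade-off is that the forward loses the compile speedup, so this only makes sense when the double backward is a hard requirement.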

How come this operation (double backward) is something that torch.compile doesn’t natively support?

There were some efforts to do it a while back, but it was hard to do and not the highest priority at the time, so it was dropped. Feel free to comment on or upvote the issue on double backward + compile.

Thanks for your thorough explanation, Jeffrey!