Hi,
I’m trying to compile my model and use its output to calculate divergence.
I’m aware that double backward is not compatible with torch.compile. So I separated the forward pass from the divergence calculation and even decorated the latter with torch.compiler.disable. No luck so far, though.
Could someone please explain why torch.compile is trying to compile the div function below even though it’s supposed to be disabled?
import torch


class Foo:
    def __init__(self):
        self.model = torch.nn.Linear(5, 5)
        self.data = [torch.rand(5, requires_grad=True) for _ in range(5)]

    @torch.compile
    def foor(self, i):
        output = self.model(self.data[i])
        return output


@torch.compiler.disable
def div(foo, output):
    # first derivative of the compiled output w.r.t. the input
    (dx,) = torch.autograd.grad(
        output,
        foo.data[1],
        grad_outputs=torch.ones_like(output),
        create_graph=True,
    )
    # second derivative -- this is the double backward that fails
    (dx2,) = torch.autograd.grad(
        dx,
        foo.data[1],
        grad_outputs=torch.ones_like(dx),
        create_graph=True,
    )
    return dx2


if __name__ == "__main__":
    foo = Foo()
    # warmup
    o1 = foo.foor(0)

    output = foo.foor(1)
    dx2 = div(foo, output)
Error message:
Traceback (most recent call last):
File "/main/torch_autograd/test.py", line 55, in <module>
dx2 = div(foo, output)
^^^^^^^^^^^^^^^^
File "/nnvenv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/main/torch_autograd/test.py", line 39, in div
(dx2,) = torch.autograd.grad(
^^^^^^^^^^^^^^^^^^^^
File "/nnvenv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 496, in grad
result = _engine_run_backward(
^^^^^^^^^^^^^^^^^^^^^
File "/nnvenv/lib/python3.11/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/nnvenv/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/nnvenv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1728, in backward
raise RuntimeError(
RuntimeError: torch.compile with aot_autograd does not currently support double backward