- How is AOTAutograd able to trace through torch framework code that lives in C++ (decomps, functionalization, etc.) that dynamo can’t (since it only “sees” Python bytecode)?
The short answer is __torch_dispatch__. If you want some code pointers: we have another API, make_fx, which is effectively a tracer that uses __torch_dispatch__ to capture every ATen op that gets executed into a graph (including ATen ops that are executed in C++, e.g. through the autograd engine). make_fx lives here, and is called by AOTAutograd here. If you want to naively capture a graph of the forward and backward (without all of the bells and whistles of functionalizing the graph, or safety from dynamo guards), make_fx is a relatively lightweight tracer that you can use to do that:
```python
# captures and prints a graph of the combined forward and backward,
# tracing through the C++ autograd engine
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w):
    return torch.matmul(x, w).sum().backward()

x = torch.randn(64, 64)
w = torch.randn(64, 64, requires_grad=True)

fx_graph = make_fx(f)(x, w)
fx_graph.print_readable()
```
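If you run this, the printed graph should contain aten ops for both the forward (e.g. mm and sum) and the backward that the C++ autograd engine dispatched (e.g. ones_like, expand, t, and the mms that compute the gradients), all flattened into a single FX graph.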
- Why is there a need for compiled_autograd in addition to AOTAutograd? I.e., why aren’t AccumulateGrad nodes able to be included in the compiled backwards graph?
The answer here (at least my interpretation) is something like: torch.compile captures almost everything in the backward that is worth compiling, but there are a few distinct things in the backward that are difficult to capture, where having dynamo run on the backward graph is helpful:
- User backward hooks. In particular, backward hooks involve tracing through user code. That code might specialize on globals, or perform Python side effects that we can’t safely graph capture, so we need dynamo to (somehow) be in the loop.
- AccumulateGrad. There are a few things that make capturing AccumulateGrad difficult (naively, without compiled autograd), mainly around handling all of the edge cases: handling post-accumulate-grad hooks, or ensuring that we properly guard on any behavior that AccumulateGrad specializes on.
- Slightly better graph break behavior. Normally, if there are two graphs in the forward, compile generates two graphs in the backward. Compiled autograd can instead ensure that we get a single graph for the entire backward, even if there are graph breaks in the forward.
Compiled autograd morally works by taking the existing autograd tape that lives in the C++ autograd engine, and retracing all of it into a python FX graph that we can then run dynamo on and compile, which handles all of those cases nicely.
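As a rough sketch of how this gets turned on (the entry points are internal and have moved between releases, so treat the exact spelling below as an assumption): compiled autograd takes a compiler_fn that it hands the retraced backward FX graph to, and newer releases also expose a torch._dynamo.config.compiled_autograd flag.

```python
# Hedged sketch: torch._dynamo.compiled_autograd.enable(...) is an internal API,
# and its exact location/signature may differ across PyTorch versions.
import torch
import torch._dynamo.compiled_autograd as compiled_autograd

def compiler_fn(gm):
    # gm is the FX graph retraced from the C++ autograd tape; send it through
    # the normal torch.compile stack (dynamo -> AOTAutograd -> inductor).
    return torch.compile(gm, backend="inductor")

model = torch.nn.Linear(64, 64)
# a user backward hook: one of the things that is hard to capture without compiled autograd
model.weight.register_hook(lambda grad: grad * 2)

x = torch.randn(8, 64)
loss = model(x).sum()

with compiled_autograd.enable(compiler_fn):
    # the whole backward (including the hook and AccumulateGrad) is retraced
    # into one FX graph and compiled, instead of being run node-by-node eagerly
    loss.backward()
```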
- How do each of these components – dynamo, AOTAutograd, compiled_autograd, and other parts of the PT2 stack – interact to enable tracing and compilation of larger, more representative graphs? What are the limitations of each component that necessitate the others, and at what level are they operating (Python, C++)?
I might be missing part of your question. But the short answer is that they are all (complementary) tracing systems.
Dynamo has the advantage that it can handle arbitrary user Python code, guarding on any implicitly specialized code and falling back to CPython when necessary, and it emits a graph of torch.* IR. AOTAutograd is a “dumb” tracer that can trace through arbitrary code, and also further trace through the C++ dispatcher, to emit a graph of aten.* IR.
Dynamo first generates a graph of torch ops (falling back to CPython and guarding on any specializations along the way), and after proving safety it sends this torch IR graph to AOTAutograd to be lowered further, also generating a backward graph. That backward graph gets inserted as a CompiledFunctionBackward node into the eager autograd engine’s tape.
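To make that hand-off concrete, here is a hedged sketch of how you can look at both levels of IR yourself (the inspect_* names are hypothetical; aot_autograd is the backend helper from torch._dynamo.backends.common): a plain custom dynamo backend receives the torch.* graph, while wrapping the same kind of compiler in aot_autograd lets AOTAutograd lower it and hand you separate aten.* forward and backward graphs.

```python
# Hedged sketch: inspect the graph at each layer of the stack.
# inspect_torch_ir / inspect_aten_ir are names made up for illustration;
# aot_autograd is the backend helper from torch._dynamo.backends.common.
import torch
from torch._dynamo.backends.common import aot_autograd

def f(x, w):
    return torch.matmul(x, w).sum()

def inspect_torch_ir(gm, example_inputs):
    # what dynamo emits: a graph of torch.* ops
    gm.print_readable()
    return gm.forward

def inspect_aten_ir(gm, example_inputs):
    # what AOTAutograd emits: forward/backward graphs of aten.* ops
    gm.print_readable()
    return gm.forward

x = torch.randn(64, 64)
w = torch.randn(64, 64, requires_grad=True)

# dynamo-level view (torch IR)
torch.compile(f, backend=inspect_torch_ir)(x, w)

# AOTAutograd-level view (aten IR, with a separate backward graph)
aot_backend = aot_autograd(fw_compiler=inspect_aten_ir, bw_compiler=inspect_aten_ir)
out = torch.compile(f, backend=aot_backend)(x, w)
out.backward()
```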
Compiled autograd is effectively an extra layer that runs at the time a user calls .backward(). Instead of letting the eager autograd engine execute each node in the tape eagerly, it traces the autograd tape itself back into an FX graph, running dynamo (and the rest of the stack) on the resulting FX graph.
- For example, when compiling FSDP-wrapped models, how do each of these parts fit together?
Hmm, I’m not sure if I’m giving you a full answer. But FSDP is effectively a piece of Python framework code, so the main differences in the PT2 stack around FSDP handling are mostly in dynamo. Dynamo will graph break on bits of FSDP that are difficult to capture. One thing worth noting, though, is that FSDP is implemented partially with autograd hooks, and so compiled autograd is one of the tools that we need to capture them properly.