Decomposition to Aten IR

@bdhirsh

Trying to understand how the decomposition of torch.nn.functional.linear is traced.

Example code:

import torch
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx

device = "cuda"
x = torch.randn(10, 10).to(device)
y = torch.randn(10, 10).to(device)

g: GraphModule = make_fx(torch.nn.functional.linear, select_decomp_table(), tracing_mode="real")(x, y)
g.print_readable()

This prints

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 10]"):
        # No stacktrace found for following nodes
        permute: "f32[10, 10]" = torch.ops.aten.permute.default(arg1_1, [1, 0]);  arg1_1 = None
        mm: "f32[10, 10]" = torch.ops.aten.mm.default(arg0_1, permute);  arg0_1 = permute = None
        return mm

I’ve tried to step through the code using pdb but still can’t figure out how the decomposition of linear is happening:

  • where is the mapping of linear to permute and mm?
  • if there is no direct mapping, how can I determine where along the trace these registered ops are being called?

I noticed that something similar is called during the pre-grad passes that Inductor runs to normalize the IR into something more amenable to downstream passes.

Also related are the AOTAutograd slides from ASPLOS 2024. Slide 23 mentions that AOTAutograd handles “operator decomposition”, among other PyTorch framework functions that occur at the C++ level, and that it does this through __torch_dispatch__.

Where / how does this call back to Python happen after the dispatching, autograd, and decomposition work is handled in C++?

Thanks!

-Jerome

Hey @Jerome_Ku. The short answer is:

We have a handful of operator decompositions that always run by default, inside the dispatcher (in C++), before making it into __torch_dispatch__.

linear/matmul is by far the most common

If you’re wondering why, the historical answer is that there are a number of ops that we don’t have dedicated derivative formulas for (e.g. linear), and so rather than writing a brand new formula, we just have the autograd engine decompose the op into more primitive ops that it does have formulas for (e.g. aten.mm and transpose).
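To make this concrete, here is a minimal sketch (not from the original post) of the same trace without any decomposition table, assuming the behavior of a recent PyTorch build: aten.linear still never shows up, because its CompositeImplicitAutograd kernel decomposes it inside the C++ dispatcher before __torch_dispatch__ ever runs.

# Minimal sketch: trace linear with make_fx but *without* a decomposition table.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

x = torch.randn(10, 10)
y = torch.randn(10, 10)

g = make_fx(torch.nn.functional.linear, tracing_mode="real")(x, y)
# Expect something like aten.t + aten.mm in the output, rather than aten.linear.
g.print_readable()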

If you are interested in export and only care about inference, we actually recently made it so that exporting for inference can preserve all ATen ops, including these special ops like linear, in the graph:

m = torch.nn.Linear(...)
graph_module = torch.export.export(m, args).run_decompositions().module()
# you should see aten.linear, as long as you didn't manually specify that you want it decomposed
print(graph_module)

@bdhirsh

Thanks for the response!

Trying to get a better mental picture of how the PT2 stack fits together in the context of more complex use cases such as tensor subclasses (quantized types, DTensor) and distributed training (FSDP).

Hoping you can help shed some light on the following, or provide some pointers on where to dig deeper (I’ve reviewed the relevant presentations from the most recent PyTorch conference as well as the ASPLOS tutorial):

  • How is AOTAutograd able to trace through torch framework code that lives in C++ (decomp, functionalization, etc) that dynamo can’t (since it only “sees” Python bytecode)?
  • Why is there a need for compiled_autograd in addition to AOTAutograd? I.e., why aren’t AccumulateGrad nodes able to be included in the compiled backwards graph?
  • How do each of these components – dynamo, AOTAutograd, compiled_autograd, and other parts of the PT2 stack – interact to enable tracing and compilation of larger, more representative graphs?
  • What are the limitations of each component that necessitates the others and at what level are they operating (Python, C++)?
  • For example, when compiling FSDP-wrapped models, how do each of these parts fit together?

Many thanks!

  • How is AOTAutograd able to trace through torch framework code that lives in C++ (decomp, functionalization, etc) that dynamo can’t (since it only “sees” Python bytecode)?

The short answer is __torch_dispatch__. If you want some code pointers: we have another API, make_fx, which is effectively a tracer that uses __torch_dispatch__ to capture every ATen op that gets executed into a graph (including ATen ops that are executed in C++, e.g. through the autograd engine). make_fx lives here and is called by AOTAutograd here.

If you want to naively capture a graph of the forward and backward (without all of the bells and whistles of functionalizing the graph, or safety from dynamo guards), make_fx is a relatively light-weight tracer that you can use to do that:

# captures and prints a graph of the combined forward and backward, tracing through the C++ autograd engine
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w):
    return torch.matmul(x, w).sum().backward()

x = torch.randn(64, 64)
w = torch.randn(64, 64, requires_grad=True)
fx_graph = make_fx(f)(x, w)
fx_graph.print_readable()
  • Why is there a need for compiled_autograd in addition to AOTAutograd? I.e., why aren’t AccumulateGrad nodes able to be included in the compiled backwards graph?

The answer here (at least my interpretation) is something like: torch.compile captures almost everything in the backward that is worth compiling, but there are a few distinct things in the backward that are difficult to capture, where having dynamo run on the backward graph is helpful:

  • User backward hooks. In particular, backward hooks involve tracing through user code. This code might involve specializing on globals, or performing Python side effects that we can’t safely graph-capture, and so we need dynamo to (somehow) be in the loop.
  • AccumulateGrad. There are a few things that make capturing AccumulateGrad difficult (naively, without compiled autograd), mainly around handling all of the edge cases: handling post-accumulate-grad hooks, or ensuring that we properly guard on any behavior that AccumulateGrad specializes on.
  • Slightly better graph-break behavior. Normally, if there are two graphs in the forward, compile generates two graphs in the backward. Compiled autograd can instead ensure that we get a single graph for the entire backward, even if there are graph breaks in the forward.

Compiled autograd morally works by taking the existing autograd tape that lives in the C++ autograd engine and retracing all of it into a Python FX graph that we can then run dynamo on and compile, which handles all of those cases nicely.
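If you want to try it out, here is a minimal sketch of turning compiled autograd on; I’m assuming the torch._dynamo.config.compiled_autograd flag from the compiled autograd tutorial, which may change between releases:

# Minimal sketch (assumed flag: torch._dynamo.config.compiled_autograd).
import torch

torch._dynamo.config.compiled_autograd = True

model = torch.nn.Linear(8, 8)
compiled_model = torch.compile(model)

out = compiled_model(torch.randn(4, 8))
# With the flag set, the backward is itself retraced into an FX graph and
# compiled, rather than being executed node-by-node by the eager engine.
out.sum().backward()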

  • How do each of these components – dynamo, AOTAutograd, compiled_autograd, and other parts of the PT2 stack – interact to enable tracing and compilation of larger, more representative graphs? … What are the limitations of each component that necessitates the others, and at what level are they operating (Python, C++)?

I might be missing part of your question. But the short answer is that they are all (complementary) tracing systems.

Dynamo has the advantage that it can handle arbitrary user Python code, guarding on any implicitly specialized code and falling back to CPython when necessary, and it emits a graph of torch.* IR. AOTAutograd is a “dumb” tracer that can trace through arbitrary code, and it can also trace further through the C++ dispatcher to emit a graph of aten.* IR.

Dynamo first generates a graph of torch ops (falling back to cpython and guarding on any specializations along the way), and after proving safety it sends this torch IR graph to AOTAutograd to be lowered further, also generating a backward graph. That backward graph gets inserted as a CompiledFunctionBackward node into the eager autograd engine’s tape.

Compiled autograd is effectively an extra layer that runs at the time a user calls .backward(). Instead of letting the eager autograd engine execute each node in the tape eagerly, it traces the autograd tape itself back into an FX graph, running dynamo (and the rest of the stack) on the resulting FX graph.
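One way to see this hand-off in practice (a sketch, not part of the original reply): compile with the aot_eager backend, which runs dynamo and AOTAutograd but skips Inductor, and enable the graph logs to print both the torch.* graph that dynamo captured and the aten.* forward/backward graphs that AOTAutograd produced.

# Sketch: inspect the dynamo -> AOTAutograd hand-off.
# Run with TORCH_LOGS="graph_code,aot_graphs" to print dynamo's torch.* graph
# and AOTAutograd's aten.* forward and backward graphs.
import torch

def f(x):
    return torch.nn.functional.relu(x).sum()

# "aot_eager" runs dynamo + AOTAutograd without the Inductor backend,
# which makes it convenient for looking at the captured graphs.
compiled = torch.compile(f, backend="aot_eager")

x = torch.randn(8, requires_grad=True)
compiled(x).backward()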

  • For example, when compiling FSDP-wrapped models, how do each of these parts fit together?

Hmm, I’m not sure if I’m giving you a full answer. But FSDP is effectively a piece of Python framework code, so the main differences in the PT2 stack around FSDP handling are mostly in dynamo: dynamo will graph-break on bits of FSDP that are difficult to capture. One thing worth noting, though, is that FSDP is implemented partially with autograd hooks, and so compiled autograd is one of the tools that we need to capture them properly.


@bdhirsh

This is a fantastic explanation – much appreciated.

A quick follow-up: how does AOTAutograd use __torch_dispatch__ to trace the backwards graph? Does it unroll the autograd tape that is recorded from the forward graph in something like FakeTensorMode and generate an fx graph from the aten ops that __torch_dispatch__ sees (at a Python level)?

It’s actually even simpler than that - we are effectively running the autograd engine as-is, except every operation that gets called ends up hitting our __torch_dispatch__ handler before we need to run its implementation.

A simple example might make things clearer. Below is a simple __torch_dispatch__ mode that just prints every op it sees.

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Run the op as usual, then log which ATen op we intercepted.
        rs = func(*args, **kwargs)
        print(f'func={str(func)}')
        return rs

def f(x):
    out = x.sin()
    out.sum().backward()

x = torch.randn(4, requires_grad=True)
with LoggingMode():
    f(x)

If you run it, you’ll see that both aten.sin and aten.cos are printed. This maps to the fact that the autograd engine has a SinBackward node saved somewhere in the autograd tape, and upon running the backward, autograd will execute that node, which eventually calls aten.cos (the derivative), which we intercept:

func=aten.sin.default
func=aten.sum.default
func=aten.ones_like.default
func=aten.expand.default
func=aten.cos.default
func=aten.mul.Tensor
func=aten.detach.default
func=aten.detach.default

In the make_fx case, our __torch_dispatch__ code is effectively intercepting each op and adding it to an FX graph (code here).
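To connect this back to make_fx (a sketch along the same lines as the logging mode above): if you trace the same sin/backward function with make_fx, the aten.cos that SinBackward calls during the backward shows up as an ordinary node in the captured graph.

# Sketch: the same function traced with make_fx instead of LoggingMode.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    x.sin().sum().backward()

x = torch.randn(4, requires_grad=True)
g = make_fx(f)(x)
g.print_readable()  # look for aten.cos among the printed nodes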


@bdhirsh

Awesome – thanks for the great explanation. All clear now!