Stitching together graph breaks for large "compilation units"

Hi all! This is my first post here and I’m looking forward to learning more about PyTorch internals.

I have a large model that I’m running through torch.compile and aot_autograd. The model has various graph breaks, mostly due to calls to torch.nonzero. For small models it’s pretty clear how to “stitch” these graphs together and run the entire compiled flow for both the forward and backward passes. But for large models this process is much more error-prone: I sometimes have hundreds of intermediate tensors to marshal between each graph and its corresponding backward graph, and it’s not always clear which tensor each “tangent” corresponds to, because I can have multiple tangents of the same shape and dtype.

Is there a way I can use PyTorch itself to untangle this mess for me automatically? In a perfect world I would have a Python function that, given input arguments to the model, shows me 1) how to invoke each forward graph, 2) what operations need to occur between graphs (if any), and 3) how to invoke each backward graph given the forward graph results. This is essentially what PyTorch does when someone invokes the compiled model (and autograd, I guess), but trying to replicate that process through the debugger seems hopeless because of all the stack layers one would have to wade through.
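One way to at least see what has to be marshaled is to have the backend record every forward and backward graph that AOTAutograd hands to it, and then list each graph's placeholder names. The sketch below is only an illustration of that idea, not a reconstruction of what dynamo does internally. It assumes the aot_autograd backend helper and make_boxed_func; the names record_fw, record_bw, placeholder_names and the toy model are hypothetical and introduced here purely for the example.

```python
import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

# Hypothetical containers for the graphs AOTAutograd passes to the backend.
forward_graphs, backward_graphs = [], []


def placeholder_names(gm: torch.fx.GraphModule):
    # Placeholder nodes are the inputs each graph expects, in call order.
    return [n.name for n in gm.graph.nodes if n.op == "placeholder"]


def record_fw(gm: torch.fx.GraphModule, example_inputs):
    forward_graphs.append(gm)
    return make_boxed_func(gm.forward)


def record_bw(gm: torch.fx.GraphModule, example_inputs):
    backward_graphs.append(gm)
    return make_boxed_func(gm.forward)


recording_backend = aot_autograd(fw_compiler=record_fw, bw_compiler=record_bw)


def model(x):
    # Toy stand-in for the real model; it has no graph breaks.
    return (x.sin() * 2).sum()


compiled = torch.compile(model, backend=recording_backend)
x = torch.randn(4, requires_grad=True)
compiled(x).backward()  # the backward graph is only compiled once backward runs

for i, gm in enumerate(forward_graphs):
    print(f"forward graph {i} inputs:", placeholder_names(gm))
for i, gm in enumerate(backward_graphs):
    print(f"backward graph {i} inputs:", placeholder_names(gm))
```

With graph breaks there would be several entries in each list. The recorded order and the placeholder names give a starting point for matching saved tensors and tangents between a forward graph and its backward graph, although they do not capture any Python code that runs between graphs.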

Thanks in advance for any tips here!

For small models it’s pretty clear how to “stitch” these graphs together and run the entire compiled flow for both the forward and backward passes

Stitching subgraphs together and figuring out when it’s safe to reuse them (vs. recompiling a new graph) is pretty non-trivial, and this is mainly what dynamo is in charge of in the first place. If you have a graph break, that technically means that arbitrary Python code can run in between your two graphs, so in general there’s no trivial relationship between the outputs of one graph and the inputs of another. What’s the motivation for trying to stitch these graphs together yourself instead of letting dynamo handle it? (I would probably focus instead on figuring out where the graph breaks are coming from; there are some tools to help make this easier.)

mostly due to calls to torch.nonzero

You can set the following config to capture nonzero in the graph (we should probably include this in the error message… although it’s possible this config will eventually disappear once capturing nonzero becomes the default):

torch._dynamo.config.capture_dynamic_output_shape_ops = True

For example, with this script:

import torch

torch._dynamo.config.capture_dynamic_output_shape_ops = True

def f(x):
    y = torch.nonzero(x)
    z = torch.ones(y.shape)
    return z.sum() + y.sum()

x = torch.ones(4)
out = f(x)

When I run the following command, I get a single graph with no graph breaks:

TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" TORCH_LOGS="aot" python


 ===== Forward graph 0 =====
 <eval_with_key>.2 from /data/users/hirsheybar/e/pytorch/torch/fx/experimental/ in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[4]"):
        # File: /data/users/hirsheybar/e/pytorch/, code: y = torch.nonzero(x)
        nonzero: "i64[i1, 1]" = torch.ops.aten.nonzero.default(arg0_1);  arg0_1 = None

        # File: /data/users/hirsheybar/e/pytorch/, code: z = torch.ones(y.shape)
        sym_size_int: "Sym(i1)" = torch.ops.aten.sym_size.int(nonzero, 0)
        ones: "f32[i1, 1]" = torch.ops.aten.ones.default([sym_size_int, 1], device = device(type='cpu'), pin_memory = False);  sym_size_int = None

        # File: /data/users/hirsheybar/e/pytorch/, code: return z.sum() + y.sum()
        sum_1: "f32[]" = torch.ops.aten.sum.default(ones);  ones = None
        sum_2: "i64[]" = torch.ops.aten.sum.default(nonzero);  nonzero = None
        add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
        return (add,)

Thanks a ton for your detailed response, Brian.

Great question. I’m writing a compiler that uses FxGraphs as an intermediate representation and generates code that will execute outside of Python. torch.export is really ideal for this use case, but torch.export does not support autograd and my application requires autograd for inference.

The next best solution would be to use torch.compile with fullgraph=True. This approach will successfully compile the forward pass (indeed despite usage of torch.nonzero, as you show) but runs into GuardOnDataDependentSymNode during the backward pass compilation. It’s unclear to me if this failure is expected for my input or if it results from an issue in AOTAutograd, so I filed an issue to try and learn more: “torch.compile of backward pass via aot_autograd (with dynamic and fullgraph) encounters GuardOnDataDependentSymNode during backward compile” (pytorch/pytorch issue #116703 on GitHub).

So in the meantime I was hoping I could live with multiple graphs via torch.compile with fullgraph=False until either the above issue is resolved or torch.export supports my use case. But from your response it sounds like that might be a bad idea. I actually was able to figure out how to stitch the forward pass for my model but the backward pass is quite complicated.

Right – since dynamo is able to execute this procedure I was hoping there might be a way for dynamo to present it to the user in the form of a Python class or function. Sounds like that functionality doesn’t exist (and it probably isn’t worth the effort to create it).

Yep - my guess is that your effort would be better spent trying to get your model into a single graph, rather than figure out a way to hand-stitch the multiple graphs together (which might be doable for your specific use case, but in general is difficult to do in a safe way).

Do you have a runnable repro / can point to the user code that is causing that exception?

Suppose you have code like this:

def f(x):
    # y's shape is data-dependent
    y = x.nonzero()
    # normally this will error, since we don't know statically what y.shape[0] is (data-dependent)
    if y.shape[0] > 1:
        return y.sin()
    else:
        return y.cos()

One thing you can try is adding torch._check(y.shape[0] > 1), if you know that for your model, it’s ok to burn in one of the two branches.

If you really need a single graph that can potentially handle both cases, in theory you should be able to do something like this:

# true_fn and false_fn each take y, e.g. returning y.sin() and y.cos()
return torch.cond(y.shape[0] > 1, true_fn, false_fn, (y,))

Although I’m not too sure what the state is of cond() where the condition contains (unbacked) dynamic shapes. It’s probably worth filing a (minimal) repro github issue if you run into problems with it.

The problem I’m having is that the compilation fails in the backward pass, so there isn’t an obvious spot to point to that is causing the exception. Here’s a minimal example:

import torch

from typing import List
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.capture_scalar_outputs = True

def forward(x):
    batch_size = x.size()[0]
    molecule_size = x.size()[1]
    edges = torch.nonzero(x > 0.5, as_tuple=True)
    index_ij = ((edges[0] * molecule_size * molecule_size) + (edges[1] * molecule_size) + edges[2])
    dist_x = (x.unsqueeze(1) - x.unsqueeze(2)).sum(3)
    dist_indexed = dist_x[index_ij]
    return dist_indexed

graph_modules = []
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    return make_boxed_func(gm.forward)

my_compiler = aot_autograd(fw_compiler=my_compiler,bw_compiler=my_compiler)

forward_aot = torch.compile(forward, fullgraph=True, dynamic=True, backend=my_compiler)

device = torch.device("cuda")

x = torch.rand([1, 5, 3], device=device)
x = x.requires_grad_(True)

vr_binned_x_ij = forward_aot(x)

I think that the gradient of dist_x[index_ij] is the problem. I guess in FxGraph speak this would mean I have a torch.ops.aten.index.Tensor operation in the forward pass that, when differentiated, turns into a torch.ops.aten.index_put.default operation in the backward pass.

Since FxGraphs require out-of-place operations, I’ve seen something like torch.ops.aten.new_zeros.default followed by torch.ops.aten.index_put.default in the backward pass. Somehow this results in two symbolic ints that dynamo wants to test equality for, but can’t:

GuardOnDataDependentSymNode: It appears that you’re trying to get a value out of symbolic int/float whose value is data-dependent (and thus we do not know the true value.) The expression we were trying to evaluate is Eq(i2, i1) (unhinted: Eq(i2, i1)). Scroll up to see where each of these data-dependent accesses originally occurred.
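The relationship between indexing in the forward pass and index_put in the backward pass can be checked in eager mode on a small tensor. This is only a sketch of the pattern described above, with hypothetical names (src, idx, manual_grad); it does not reproduce the compile failure or explain where Eq(i2, i1) comes from.

```python
import torch

src = torch.arange(5.0, requires_grad=True)
idx = torch.tensor([0, 2, 2])  # index 2 is gathered twice

# Forward: advanced indexing (aten.index.Tensor in graph terms).
picked = src[idx]

# Backward: autograd scatters the incoming gradient back into src's shape.
grad_out = torch.ones(3)
picked.backward(grad_out)

# The same result written out-of-place: zeros, then index_put with accumulation.
manual_grad = torch.zeros(5).index_put((idx,), grad_out, accumulate=True)

print(src.grad)
print(manual_grad)
```

Both tensors come out equal, with the repeated index accumulating twice. In the data-dependent case, the length of idx and the length of the incoming gradient are both symbolic sizes that the index_put has to treat as compatible.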