Mapping between TorchDynamo compiled functions and FX GraphModules

Problem Description

I’m trying to understand the relationship between the functions TorchDynamo generates (such as __compiled_fn_1 and __resume_at_30_2) and the FX GraphModules passed to my backend. I need to know which graph corresponds to which function, especially when data-dependent control flow causes graph breaks and produces several graphs.

Code Example

import os
os.environ["TORCH_LOGS"] = "+graph_sizes,graph_breaks"

from typing import List
import torch
from torch import _dynamo as torchdynamo

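# Dynamo-generated global names seen so far, and the graphs passed to the backend
compiled_functions = []
gms = []

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Dynamo installs its generated functions (__compiled_fn_*, __resume_at_*)
    # into the globals of the compiled function's module, i.e. this module
    for key in globals():
        if (key.startswith("__compiled") or key.startswith("__resume")) and key not in compiled_functions:
            compiled_functions.append(key)
            print(f"New function added: {key}")

    gms.append(gm)
    print(gm.code.strip())
    # Run the captured graph unchanged
    return gm.forward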

@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(1):
    toy_example(torch.randn(10), torch.randn(10))

Output

def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
    l_a_ = L_a_
    l_b_ = L_b_
    abs_1 = torch.abs(l_a_)
    add = abs_1 + 1;  abs_1 = None
    x = l_a_ / add;  l_a_ = add = None
    sum_1 = l_b_.sum();  l_b_ = None
    lt = sum_1 < 0;  sum_1 = None
    return (x, lt)
New function added: __compiled_fn_1
New function added: __resume_at_30_2
New function added: __resume_at_38_3
def forward(self, L_x_ : torch.Tensor, L_b_ : torch.Tensor):
    l_x_ = L_x_
    l_b_ = L_b_
    mul = l_x_ * l_b_;  l_x_ = l_b_ = None
    return (mul,)

Question

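I notice that TorchDynamo installs all three functions (__compiled_fn_1, __resume_at_30_2, __resume_at_38_3) into the global namespace while processing the first frame: all three names are already present when the compiler is called for the second graph, even though only one of the two resume functions ever gets a graph in this run (whichever branch the input takes).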
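The generated names in the output appear to follow two patterns, __compiled_fn_<id> and __resume_at_<offset>_<id>, where <id> comes from a single counter shared by both prefixes (1, 2, 3 above) and <offset> is the bytecode offset at which the resume function continues. A minimal sketch, assuming that naming scheme holds (it is inferred from this output, not a documented API), that extracts these names from a namespace and orders them by creation:

```python
import re
from typing import List, NamedTuple, Optional

# Assumption: Dynamo names its generated globals "__compiled_fn_<id>" and
# "__resume_at_<offset>_<id>", with <id> drawn from one shared counter.
# This is inferred from the log output, not a documented API.
_NAME_RE = re.compile(r"^__(compiled_fn|resume_at_(\d+))_(\d+)$")


class DynamoName(NamedTuple):
    name: str              # the global's full name
    kind: str              # "compiled_fn" or "resume_at"
    offset: Optional[int]  # bytecode offset where execution resumes (resume functions only)
    uid: int               # value of the shared counter when the name was created


def parse_dynamo_name(name: str) -> Optional[DynamoName]:
    """Parse a Dynamo-generated global name; return None for anything else."""
    m = _NAME_RE.match(name)
    if m is None:
        return None
    if m.group(2) is None:
        return DynamoName(name, "compiled_fn", None, int(m.group(3)))
    return DynamoName(name, "resume_at", int(m.group(2)), int(m.group(3)))


def in_creation_order(names: List[str]) -> List[DynamoName]:
    """Keep only Dynamo-generated names and sort them by the shared counter."""
    parsed = [p for p in map(parse_dynamo_name, names) if p is not None]
    return sorted(parsed, key=lambda p: p.uid)
```

For example, in_creation_order(list(globals())) after a run lists the generated functions in the order Dynamo created them. Creation order is not the order in which graphs reach the backend: in the output above, the second graph is compiled only after both resume functions already exist, and its own __compiled_fn_<id> name would draw a later value from the same counter.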

How can I determine which FX GraphModule corresponds to which compiled function name? This is especially important when the execution path depends on the input data (like the conditional in my example).

Is there an API or debugging feature in TorchDynamo that exposes this mapping? If not, what would be the recommended approach to track this relationship for arbitrary code?
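One debugging feature worth trying is TorchDynamo's logging. A minimal sketch, assuming the graph_code and bytecode log artifacts of PyTorch 2.x: graph_code prints each captured graph under a header carrying its __compiled_fn_<id> name, and bytecode prints each frame's original and rewritten bytecode (including the __resume_at_* frames), where the rewritten code shows which __compiled_fn_<id> global it loads. The exact log format is an implementation detail and may change between versions.

```python
import os

# Must run before `import torch`, as in the original script.
# Assumption: "graph_code" and "bytecode" are valid TORCH_LOGS artifacts in
# PyTorch 2.3. graph_code labels each captured FX graph with its
# __compiled_fn_<id> name; bytecode shows which __compiled_fn_<id> each
# (original or __resume_at_*) frame calls after Dynamo rewrites it.
os.environ["TORCH_LOGS"] = "graph_code,bytecode,graph_breaks"
```

Setting the variable from the shell, or calling torch._logging.set_logs(graph_code=True, bytecode=True, graph_breaks=True) after importing torch, should have the same effect.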

Environment

  • PyTorch version: 2.3.0
  • OS: Ubuntu 22.04
  • Python version: 3.10.0

I’m curious about why you want to know this information. Are you writing a custom backend?

Yes, I’m writing a custom backend. Is there a good way to know which graph module corresponds to which function?
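For a custom backend, one option is to let the backend record exactly what it returns and match that against the module globals afterwards. A minimal sketch; RecordingBackend and match are hypothetical names, and the approach assumes that, with default settings, Dynamo installs the backend's return value as __compiled_fn_<id> wrapped only by a functools.wraps-style wrapper (torch._dynamo.disable in PyTorch 2.3), so inspect.unwrap recovers the original object. That is an implementation detail worth verifying on your version.

```python
import inspect
from typing import Any, Callable, Dict, List, Tuple


class RecordingBackend:
    """Hypothetical backend that remembers which callable it returned for each graph."""

    def __init__(self) -> None:
        self.records: List[Tuple[Any, Callable[..., Any]]] = []  # (graph_module, returned_callable)

    def __call__(self, gm: Any, example_inputs: List[Any]) -> Callable[..., Any]:
        # A fresh closure per graph gives each compiled callable a unique identity
        # (gm.forward creates a new bound-method object on every attribute access).
        def compiled(*args: Any) -> Any:
            return gm.forward(*args)

        self.records.append((gm, compiled))
        return compiled

    def match(self, namespace: Dict[str, Any]) -> Dict[str, Any]:
        """Map each __compiled_fn_* entry in `namespace` to the graph module it runs.

        Assumption: the installed global wraps our closure via __wrapped__
        (functools.wraps), so inspect.unwrap gets back the object we returned.
        """
        # Closures stay alive in self.records, so their id() values stay valid.
        by_id = {id(fn): gm for gm, fn in self.records}
        mapping: Dict[str, Any] = {}
        for name, value in namespace.items():
            if name.startswith("__compiled_fn") and callable(value):
                gm = by_id.get(id(inspect.unwrap(value)))
                if gm is not None:
                    mapping[name] = gm
        return mapping
```

With backend = RecordingBackend(), decorating toy_example with @torchdynamo.optimize(backend) instead of my_compiler and calling backend.match(globals()) after a run should give a name-to-graph mapping. Resume functions never reach the backend directly: each __resume_at_* is a plain Python function that Dynamo compiles into its own __compiled_fn_<id> the first time it runs, so this matching covers only the __compiled_fn_* names.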