Mapping between TorchDynamo compiled functions and FX GraphModules
Problem Description
I’m trying to understand the relationship between TorchDynamo compiled functions (like __compiled_fn_1, __resume_at_30_2) and their corresponding FX GraphModules. When working with TorchDynamo, I need to know which graph corresponds to which function, especially when control flow causes multiple graph breaks.
Code Example
import os
os.environ["TORCH_LOGS"] = "+graph_sizes,graph_breaks"  # enable graph-size and graph-break logging

from typing import List

import torch
from torch import _dynamo as torchdynamo

compiled_functions = []
gms = []

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Record any Dynamo-generated names that have appeared in the
    # global namespace since the previous compile.
    for key in globals():
        if (key.startswith("__compiled") or key.startswith("__resume")) and key not in compiled_functions:
            compiled_functions.append(key)
            print(f"New function added: {key}")
    gms.append(gm)
    print(gm.code.strip())
    return gm.forward

@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(1):
    toy_example(torch.randn(10), torch.randn(10))
Output
def forward(self, L_a_ : torch.Tensor, L_b_ : torch.Tensor):
    l_a_ = L_a_
    l_b_ = L_b_
    abs_1 = torch.abs(l_a_)
    add = abs_1 + 1; abs_1 = None
    x = l_a_ / add; l_a_ = add = None
    sum_1 = l_b_.sum(); l_b_ = None
    lt = sum_1 < 0; sum_1 = None
    return (x, lt)
New function added: __compiled_fn_1
New function added: __resume_at_30_2
New function added: __resume_at_38_3
def forward(self, L_x_ : torch.Tensor, L_b_ : torch.Tensor):
    l_x_ = L_x_
    l_b_ = L_b_
    mul = l_x_ * l_b_; l_x_ = l_b_ = None
    return (mul,)
Question
I notice that TorchDynamo adds all three compiled function names (__compiled_fn_1, __resume_at_30_2, __resume_at_38_3) to the global namespace during the initial bytecode analysis, before all of the corresponding graph modules have even been created.
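Because of this timing, the obvious workaround of diffing globals() between backend invocations and pairing each new name with the GraphModule from that call breaks down: nothing new is visible on the first call, and all three names show up together on the second. A minimal sketch of that (flawed) idea, assuming the same single-file setup as above:

seen_names = set()
name_to_gm = {}

def diffing_compiler(gm, example_inputs):
    # Pair this GraphModule with whatever Dynamo names appeared since the
    # previous compile. Flawed: as the output above shows, all three names
    # arrive at once, so the pairing is ambiguous.
    new_names = frozenset(k for k in globals()
                          if k.startswith(("__compiled", "__resume")) and k not in seen_names)
    seen_names.update(new_names)
    name_to_gm[new_names] = gm
    return gm.forward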
How can I determine which FX GraphModule corresponds to which compiled function name? This is especially important when the execution path depends on the input data (like the conditional in my example).
Is there an API or debugging feature in TorchDynamo that exposes this mapping? If not, what would be the recommended approach to track this relationship for arbitrary code?
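For reference, the closest built-in tool I have found is torch._dynamo.explain, which reports the captured graphs and graph-break reasons, but as far as I can tell it does not expose the generated __compiled_fn_* / __resume_at_* names. A quick sketch of what it does give you:

import torch
import torch._dynamo

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

explanation = torch._dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation.graph_count)        # 2 for this example
print(explanation.graph_break_count)  # 1: the data-dependent if
for gm in explanation.graphs:         # the FX GraphModules, without names attached
    print(gm.code.strip())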
Environment
- PyTorch version: 2.3.0
- OS: Ubuntu 22.04
- Python version: 3.10.0