How to obtain the variable association relationship between the forward and backward FX graphs?

  • I have some simple code to get the variable association relationship between the forward and backward FX graphs. It is based on a mapping between operator types and was written with the help of DeepSeek.

https://chat.deepseek.com/a/chat/s/7ce16ac6-7fc6-4cbf-bc03-04ebf4636ef3

  • However, it has many limitations, such as the inability to handle
    a) duplicate operands in an FX graph
    b) multiple outputs from a single operator

Any ideas on how to improve it? (One possible node-level direction is sketched after the code below.)

import re
from collections import defaultdict

class FXGraphRelationAnalyzer:
    def __init__(self, forward_code, backward_code):
        self.forward_code = forward_code
        self.backward_code = backward_code
        self.forward_nodes = self._parse_code(forward_code)
        self.backward_nodes = self._parse_code(backward_code)
        self.op_mapping = self._build_op_mapping()
        self.relations = self._build_relations()

    def _parse_code(self, code):
        nodes = {}
        lines = code.strip().split('\n')
        
        # Improved regular expression: match the FX graph line format
        # Example: "view: f32[32, 784] = torch.ops.aten.view.default(x, [32, -1])"
        assign_pattern = re.compile(r'^(\w+)(?::\s*[^=]+)?\s*=\s*(\w+(?:\.\w+)*\.\w+)')
        
        print(f"\n解析代码: {code[:50]}...")
        for i, line in enumerate(lines):
            line = line.strip()
            if not line or line.startswith(('#', 'def', 'return', '@')):
                continue
                
            # Debug output
            print(f"Parsing line {i+1}: '{line}'")
            
            match = assign_pattern.match(line)
            if match:
                var_name, op_name = match.groups()
                print(f"  匹配成功: 变量={var_name}, 操作={op_name}")
                nodes[var_name] = {
                    'operation': op_name,
                    'full_line': line
                }
            else:
                print(f"  无法解析的行: '{line}'")
        
        print(f"解析完成,找到 {len(nodes)} 个节点")
        return nodes

    def _build_op_mapping(self):
        # Mapping from forward ops to their backward counterparts
        return {
            'torch.ops.aten.view.default': 'torch.ops.aten.view_backward.default',
            'torch.ops.aten.addmm.default': 'torch.ops.aten.mm.default',
            'torch.ops.aten.relu.default': 'torch.ops.aten.threshold_backward.default',
            'torch.ops.aten.linear.default': 'torch.ops.aten.linear_backward.default',
            'torch.ops.aten.sigmoid.default': 'torch.ops.aten.sigmoid_backward.default',
            'torch.ops.aten.nll_loss_forward.default': 'torch.ops.aten.nll_loss_backward.default',
            'torch.ops.aten.t.default': 'torch.ops.aten.t_backward.default',
            'torch.ops.aten.sum.default': 'torch.ops.aten.sum_backward.default',
            'torch.ops.aten.mm.default': 'torch.ops.aten.mm_backward.default',
            'torch.ops.aten.clone.default': 'torch.ops.aten.clone_backward.default',
        }

    def _build_relations(self):
        relations = defaultdict(list)
        
        print("\n=== 前向节点分析 ===")
        for fwd_var, fwd_info in self.forward_nodes.items():
            print(f"前向变量: {fwd_var}, 操作: {fwd_info['operation']}")
            
        print("\n=== 反向节点分析 ===")
        for bwd_var, bwd_info in self.backward_nodes.items():
            print(f"反向变量: {bwd_var}, 操作: {bwd_info['operation']}")
        
        print("\n=== 关联关系构建 ===")
        for fwd_var, fwd_info in self.forward_nodes.items():
            fwd_op = fwd_info['operation']
            
            # Look up the corresponding backward operation
            target_bwd_op = self.op_mapping.get(fwd_op)
            if not target_bwd_op:
                print(f"警告: 没有为前向操作 {fwd_op} 找到映射的反向操作")
                continue
                
            print(f"查找 {fwd_op} 的反向操作 {target_bwd_op}")
            for bwd_var, bwd_info in self.backward_nodes.items():
                if bwd_info['operation'] == target_bwd_op:
                    print(f"  找到匹配: {fwd_var} ({fwd_op}) → {bwd_var} ({target_bwd_op})")
                    relations[fwd_var].append({
                        'backward_var': bwd_var,
                        'relationship': 'op_type_mapping',
                        'forward_op': fwd_op,
                        'backward_op': target_bwd_op
                    })
        
        # Additional relations: parameter dependencies
        for bwd_var, bwd_info in self.backward_nodes.items():
            # Look for forward variable names in the backward line
            for fwd_var in self.forward_nodes:
                if f' {fwd_var},' in bwd_info['full_line'] or f' {fwd_var})' in bwd_info['full_line']:
                    print(f"  参数依赖: 反向操作 {bwd_var} 依赖前向变量 {fwd_var}")
                    relations[fwd_var].append({
                        'backward_var': bwd_var,
                        'relationship': 'parameter_dependency',
                        'forward_op': self.forward_nodes[fwd_var]['operation'],
                        'backward_op': bwd_info['operation']
                    })
        
        return relations

    def print_relations(self):
        print("\n===== 前向-反向变量关联关系 =====")
        if not self.relations:
            print("没有找到任何关联关系")
            return
            
        for fwd_var, rels in self.relations.items():
            print(f"\n前向变量: {fwd_var} ({self.forward_nodes[fwd_var]['operation']})")
            for rel in rels:
                relation_type = "操作类型映射" if rel['relationship'] == 'op_type_mapping' else "参数依赖"
                print(f"  → 反向变量: {rel['backward_var']} ({rel['backward_op']})")
                print(f"    关系类型: {relation_type}")

# Test code (using a more realistic FX graph format)
forward_code = """
def forward(self, x):
    view: f32[32, 784] = torch.ops.aten.view.default(x, [32, -1])
    addmm: f32[32, 64] = torch.ops.aten.addmm.default(w1, view, b1)
    relu: f32[32, 64] = torch.ops.aten.relu.default(addmm)
    return relu
"""

backward_code = """
def backward(self, grad_output):
    threshold_backward: f32[32, 64] = torch.ops.aten.threshold_backward.default(grad_output, 0)
    mm: f32[64, 784] = torch.ops.aten.mm.default(threshold_backward, view)
    view_backward: f32[32, 1, 784] = torch.ops.aten.view_backward.default(mm, x.shape)
    return view_backward
"""

print("="*50 + " 开始测试 " + "="*50)
analyzer = FXGraphRelationAnalyzer(forward_code, backward_code)
analyzer.print_relations()
print("="*50 + " 测试结束 " + "="*50)

Another attempt, made with the help of Doubao, was to build the data dependence based on the first input operand of each operator, which is also limited; a rough sketch of that heuristic follows.
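
The sketch below is my reconstruction of that first-operand heuristic (an assumption, so the details may differ from the actual attempt):

import re

# Match "out: type = torch.ops.aten.some_op.default(first_arg, ...)" and keep
# only the output name and its first operand.
FIRST_ARG = re.compile(r'^(\w+)[^=]*=\s*torch\.ops\.aten\.[\w.]+\((\w+)')

def first_operand_chain(code):
    deps = {}
    for line in code.strip().splitlines():
        m = FIRST_ARG.match(line.strip())
        if m:
            out, first_arg = m.groups()
            deps[out] = first_arg   # e.g. deps['relu'] == 'addmm' for the forward code above
    return deps

It records exactly one input per node and ignores all other operands, which is the limitation mentioned above.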

How are you getting the backward fx graph? It's easier if you are able to do some setup there ahead of time, e.g. run with hooks. We have some infrastructure that does this association in torch.compile using this method.

Thanks for your attention.

I get the forward and backward FX graphs with make_fx; maybe it is not a good method, so it would be better if you could share your hook-based approach.
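
In the meantime, my guess at what a hook-based setup could look like: torch.autograd.graph.saved_tensors_hooks records which tensors the forward pass saves for backward and when they are retrieved, which already gives one kind of forward/backward association. This is only an assumption; the torch.compile infrastructure mentioned above may work quite differently.

import torch

saved = []   # one record per tensor the forward pass saves for backward

def pack_hook(t):
    idx = len(saved)
    saved.append({'index': idx, 'shape': tuple(t.shape),
                  'dtype': t.dtype, 'used_in_backward': False})
    return (idx, t)          # keep the tensor itself so backward still works

def unpack_hook(packed):
    idx, t = packed
    saved[idx]['used_in_backward'] = True   # retrieved during the backward pass
    return t

model = torch.nn.Linear(784, 10)   # hypothetical stand-in for the real model
x = torch.randn(32, 784)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss = model(x).sum()

loss.backward()
print(saved)

My current make_fx-based capture is below.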

def run_autograd_ahead_of_time_forward_only_2(function, dataloader):
    # Inner function: run the model forward pass and compute the loss
    def run_and_get_loss(inputs, labels):
        outputs = function(inputs)          # model forward pass
        loss = criterion(outputs, labels)   # compute the loss
        return loss                        # return the loss value
    
    # Iterate over the dataloader
    for input_data, label_data in tqdm(dataloader, desc='Validation'):
        inputs = input_data              # input batch
        labels = label_data              # label batch
        
        # Import the FX symbolic tracing helper
        from torch.fx.experimental.proxy_tensor import make_fx
        
        
        # Create the FX tracer, wrapping the loss computation
        wrapped_function = make_fx(
            run_and_get_loss,             # function to trace
            tracing_mode="real",          # trace with real tensors
            _allow_non_fake_inputs=True   # allow non-FakeTensor inputs
        )
        
        # Run the trace to produce the computation graph
        joint_graph = wrapped_function(inputs, labels)
        interpreter = CaptureInterpreter(joint_graph)
        output = interpreter.run(inputs, labels)
        
        # Compile the graph with the NPU backend
        model_npu = torch.compile(
            joint_graph,                  # FX graph to compile
            backend=npu_backend,          # NPU backend
            dynamic=False                 # disable dynamic shapes
        )
        
        # Export the model in TorchAIR format
        tng.dynamo_export(
            inputs, labels,               # input data and labels
            model=model_npu,              # compiled model
            export_path="split_graph_fw_0609",  # export path
            dynamic=False                 # disable dynamic export
        )
        
        # atc --framework=1 --model=./export.air --output=./fw --input_format=NCHW --soc_version=Ascend310P3
        # Print the graph's Python code representation
        print(joint_graph._graph.python_code(root_module="self", verbose=True).src)
        break  # only process one batch

def run_autograd_ahead_of_time_backward_only_2(function, dataloader):
    # Inner function: compute parameter gradients from the loss
    def run_from_loss(loss):
        # Compute parameter gradients; retain_graph=True allows multiple backward passes
        grads = torch.autograd.grad(
            loss,                       # loss as the starting point of the gradient computation
            function.parameters(),      # gradients for all model parameters
            retain_graph=True,          # keep the computation graph
            allow_unused=True           # avoid errors for parameters not involved in the computation
        )
        return grads                   # return the list of gradients
    
    # Iterate over the dataloader
    for input_data, label_data in tqdm(dataloader, desc='Validation'):
        inputs = input_data              # input batch
        labels = label_data              # label batch
        outputs = function(inputs)       # model forward pass
        loss = criterion(outputs, labels) # compute the loss
        
        # Import the FX symbolic tracing helper
        from torch.fx.experimental.proxy_tensor import make_fx
        
        # Create the FX tracer, wrapping the gradient computation
        wrapped_function = make_fx(
            run_from_loss,                # function to trace (gradient computation)
            tracing_mode="real",          # trace with real tensors
            _allow_non_fake_inputs=True   # allow non-FakeTensor inputs
        )
        
        # Run the trace; with the loss as input this yields the backward graph
        joint_graph = wrapped_function(loss)
        interpreter = CaptureInterpreter(joint_graph)
        output = interpreter.run(loss)
        
        # Compile the graph with the NPU backend
        model_npu = torch.compile(
            joint_graph,                  # FX graph to compile
            backend=npu_backend,          # NPU backend
            dynamic=False                 # disable dynamic shapes
        )
        
        # Export the model in TorchAIR format
        tng.dynamo_export(
            loss,                         # the loss is the export input
            model=model_npu,              # compiled model
            export_path="split_graph_bw_0609",  # export path
            dynamic=False                 # disable dynamic export
        )
        
        # atc --framework=1 --model=./export.air --output=./bw --input_format=NCHW --soc_version=Ascend310P3
        # Print the graph's Python code representation
        print(joint_graph._graph.python_code(root_module="self", verbose=True).src)
        break  # only process one batch

# Data preprocessing (make sure the input is a proper 2D tensor)
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.reshape(-1, 784))  # force shape [batch, 784]
])

train_dt = datasets.MNIST(root='./data', train=True, transform=transform, download=False)
train_loader = torch.utils.data.DataLoader(dataset=train_dt, batch_size=32, shuffle=False)

run_autograd_ahead_of_time_forward_only_2(model, train_loader)     # pure forward graph
run_autograd_ahead_of_time_backward_only_2(model, train_loader)  # pure backward graph

Now I find it may be a bit easier to get the variable association relationship between the pure backward graph and the backward part of the joint forward-and-backward graph, since their operator types are the same, but there are still some extra operators, such as detach, that interfere with the matching (a matching sketch follows the examples below).

  • the pure backward graph looks something like:
def backward(self, loss_1: f32[]):
    # No stacktrace found for following nodes
    ones_like: f32[] = torch.ops.aten.ones_like.default(loss_1, pin_memory = False, memory_format = torch.preserve_format);  loss_1 = None
    _tensor_constant0 = self._tensor_constant0
    _tensor_constant1 = self._tensor_constant1
    _tensor_constant2 = self._tensor_constant2
    nll_loss_backward: f32[32, 10] = torch.ops.aten.nll_loss_backward.default(ones_like, _tensor_constant0, _tensor_constant1, None, 1, -100, _tensor_constant2);  ones_like = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
    _tensor_constant3 = self._tensor_constant3
    _log_softmax_backward_data: f32[32, 10] = torch.ops.aten._log_softmax_backward_data.default(nll_loss_backward, _tensor_constant3, 1, torch.float32);  nll_loss_backward = _tensor_constant3 = None
   ...
  • the joint forward-and-backward graph looks something like:
def forward_and_backward(self, inputs_1: f32[32, 1, 784], labels_1: i64[32]):
    # ...  forward part deleted here
    
    ones_like: f32[] = torch.ops.aten.ones_like.default(getitem, pin_memory = False, memory_format = torch.preserve_format)
    nll_loss_backward: f32[32, 10] = torch.ops.aten.nll_loss_backward.default(ones_like, _log_softmax, labels_1, None, 1, -100, getitem_1);  ones_like = _log_softmax = labels_1 = getitem_1 = None
    detach_3: f32[32, 10] = torch.ops.aten.detach.default(detach_2);  detach_2 = None
    _log_softmax_backward_data: f32[32, 10] = torch.ops.aten._log_softmax_backward_data.default(nll_loss_backward, detach_3, 1, torch.float32);  nll_loss_backward = detach_3 = None

   ...

Then I hope to get a variable association relationship such as:
{ "loss_1": "getitem", "_log_softmax": "_tensor_constant0", ... }
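
A minimal matching sketch along those lines (a rough draft under simplifying assumptions: the joint-graph code has already been trimmed to its backward region, the two graphs list the remaining backward ops in the same order, and detach is the only noise op to drop). The mapping direction here is pure-backward name → joint-graph name, e.g. loss_1 → getitem and _tensor_constant0 → _log_softmax.

import re

ATEN_CALL = re.compile(r'^(\w+)[^=]*=\s*(torch\.ops\.aten\.[\w.]+)\((.*)\)')

def parse_aten_lines(code, skip_ops=('torch.ops.aten.detach.default',)):
    # Extract (output_var, op, argument-name list) per line, skipping noise ops.
    nodes = []
    for line in code.strip().splitlines():
        m = ATEN_CALL.match(line.strip())
        if not m or m.group(2) in skip_ops:
            continue
        # identifier-like tokens only; literals such as 1 and -100 are dropped
        arg_names = re.findall(r'[A-Za-z_]\w*', m.group(3))
        nodes.append((m.group(1), m.group(2), arg_names))
    return nodes

def match_backward_graphs(pure_bwd_code, joint_bwd_code):
    # Align the two graphs op-by-op and pair their variable names positionally.
    mapping = {}
    for (o1, op1, a1), (o2, op2, a2) in zip(parse_aten_lines(pure_bwd_code),
                                            parse_aten_lines(joint_bwd_code)):
        if op1 != op2:
            break    # the graphs diverged; a real version needs fuzzier alignment
        mapping[o1] = o2
        for x, y in zip(a1, a2):
            mapping.setdefault(x, y)
    return mapping

On the two excerpts above this also pairs _tensor_constant1 with labels_1, _tensor_constant2 with getitem_1, and _tensor_constant3 with detach_3; keyword tokens such as pin_memory and None come through as noise and would need filtering, and the positional zip alignment would have to become a real graph alignment for anything larger.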