- I have a simple script that recovers the variable association between the forward and backward FX graphs, based on a mapping between operator types (written with the help of DeepSeek).
  https://chat.deepseek.com/a/chat/s/7ce16ac6-7fc6-4cbf-bc03-04ebf4636ef3
- However, it has many limitations, such as the inability to handle
  a) duplicate operands in an FX graph
  b) multiple outputs from a single operator
  …
Does anyone have an idea how to improve it? (One possible direction is sketched after the script below.)
import re
from collections import defaultdict


class FXGraphRelationAnalyzer:
    def __init__(self, forward_code, backward_code):
        self.forward_code = forward_code
        self.backward_code = backward_code
        self.forward_nodes = self._parse_code(forward_code)
        self.backward_nodes = self._parse_code(backward_code)
        self.op_mapping = self._build_op_mapping()
        self.relations = self._build_relations()
    def _parse_code(self, code):
        nodes = {}
        lines = code.strip().split('\n')
        # Improved regex: matches the FX graph assignment format
        # Example: "view: f32[32, 784] = torch.ops.aten.view.default(x, [32, -1])"
        assign_pattern = re.compile(r'^(\w+)(?::\s*[^=]+)?\s*=\s*(\w+(?:\.\w+)*\.\w+)')
        print(f"\nParsing code: {code[:50]}...")
        for i, line in enumerate(lines):
            line = line.strip()
            if not line or line.startswith(('#', 'def', 'return', '@')):
                continue
            # Debug output
            print(f"Parsing line {i+1}: '{line}'")
            match = assign_pattern.match(line)
            if match:
                var_name, op_name = match.groups()
                print(f"  Matched: var={var_name}, op={op_name}")
                nodes[var_name] = {
                    'operation': op_name,
                    'full_line': line
                }
            else:
                print(f"  Could not parse line: '{line}'")
        print(f"Parsing done, found {len(nodes)} nodes")
        return nodes
    def _build_op_mapping(self):
        # Hand-written mapping from forward ops to their backward counterparts
        return {
            'torch.ops.aten.view.default': 'torch.ops.aten.view_backward.default',
            'torch.ops.aten.addmm.default': 'torch.ops.aten.mm.default',
            'torch.ops.aten.relu.default': 'torch.ops.aten.threshold_backward.default',
            'torch.ops.aten.linear.default': 'torch.ops.aten.linear_backward.default',
            'torch.ops.aten.sigmoid.default': 'torch.ops.aten.sigmoid_backward.default',
            'torch.ops.aten.nll_loss_forward.default': 'torch.ops.aten.nll_loss_backward.default',
            'torch.ops.aten.t.default': 'torch.ops.aten.t_backward.default',
            'torch.ops.aten.sum.default': 'torch.ops.aten.sum_backward.default',
            'torch.ops.aten.mm.default': 'torch.ops.aten.mm_backward.default',
            'torch.ops.aten.clone.default': 'torch.ops.aten.clone_backward.default',
        }
    def _build_relations(self):
        relations = defaultdict(list)
        print("\n=== Forward node analysis ===")
        for fwd_var, fwd_info in self.forward_nodes.items():
            print(f"Forward var: {fwd_var}, op: {fwd_info['operation']}")
        print("\n=== Backward node analysis ===")
        for bwd_var, bwd_info in self.backward_nodes.items():
            print(f"Backward var: {bwd_var}, op: {bwd_info['operation']}")
        print("\n=== Building relations ===")
        for fwd_var, fwd_info in self.forward_nodes.items():
            fwd_op = fwd_info['operation']
            # Look up the corresponding backward op
            target_bwd_op = self.op_mapping.get(fwd_op)
            if not target_bwd_op:
                print(f"Warning: no backward op mapped for forward op {fwd_op}")
                continue
            print(f"Looking for backward op {target_bwd_op} of {fwd_op}")
            for bwd_var, bwd_info in self.backward_nodes.items():
                if bwd_info['operation'] == target_bwd_op:
                    print(f"  Match found: {fwd_var} ({fwd_op}) → {bwd_var} ({target_bwd_op})")
                    relations[fwd_var].append({
                        'backward_var': bwd_var,
                        'relationship': 'op_type_mapping',
                        'forward_op': fwd_op,
                        'backward_op': target_bwd_op
                    })
        # Add extra relations: parameter dependencies
        for bwd_var, bwd_info in self.backward_nodes.items():
            # Look for forward variable names in the backward call's argument list
            # (word-boundary match on the text inside the parentheses, so the first
            # argument is also caught; this is still fragile string matching)
            args_match = re.search(r'\((.*)\)', bwd_info['full_line'])
            args_part = args_match.group(1) if args_match else ''
            for fwd_var in self.forward_nodes:
                if re.search(rf'\b{re.escape(fwd_var)}\b', args_part):
                    print(f"  Parameter dependency: backward op {bwd_var} depends on forward var {fwd_var}")
                    relations[fwd_var].append({
                        'backward_var': bwd_var,
                        'relationship': 'parameter_dependency',
                        'forward_op': self.forward_nodes[fwd_var]['operation'],
                        'backward_op': bwd_info['operation']
                    })
        return relations
    def print_relations(self):
        print("\n===== Forward-backward variable relations =====")
        if not self.relations:
            print("No relations found")
            return
        for fwd_var, rels in self.relations.items():
            print(f"\nForward var: {fwd_var} ({self.forward_nodes[fwd_var]['operation']})")
            for rel in rels:
                relation_type = "op type mapping" if rel['relationship'] == 'op_type_mapping' else "parameter dependency"
                print(f"  → Backward var: {rel['backward_var']} ({rel['backward_op']})")
                print(f"    Relation type: {relation_type}")
# Test code (using a more realistic FX graph format)
forward_code = """
def forward(self, x):
    view: f32[32, 784] = torch.ops.aten.view.default(x, [32, -1])
    addmm: f32[32, 64] = torch.ops.aten.addmm.default(w1, view, b1)
    relu: f32[32, 64] = torch.ops.aten.relu.default(addmm)
    return relu
"""

backward_code = """
def backward(self, grad_output):
    threshold_backward: f32[32, 64] = torch.ops.aten.threshold_backward.default(grad_output, 0)
    mm: f32[64, 784] = torch.ops.aten.mm.default(threshold_backward, view)
    view_backward: f32[32, 1, 784] = torch.ops.aten.view_backward.default(mm, x.shape)
    return view_backward
"""

print("=" * 50 + " Test start " + "=" * 50)
analyzer = FXGraphRelationAnalyzer(forward_code, backward_code)
analyzer.print_relations()
print("=" * 50 + " Test end " + "=" * 50)