How to obtain the input variables of each node (operation) in an FX IR

  • For example, given an FX IR in string form (below), is there an API that can directly obtain the input variables of each operation in it?
def backward(self, grad_output):
    """Illustrative FX IR whose per-node inputs the question wants to recover.

    Each line binds one node as
    ``<name>: <annotation> = torch.ops.aten.<op>.<overload>(<inputs...>)``,
    where the annotation looks like ``dtype[shape]`` (e.g. ``f32[32, 64]``).
    The desired output is each node's input list, e.g. ``mm`` takes
    ``threshold_backward`` and ``view``.

    NOTE(review): ``view`` and ``x`` are read below but are neither parameters
    nor assigned in this body -- presumably they come from the forward graph;
    confirm before trying to execute this.
    """
    threshold_backward: f32[32, 64] = torch.ops.aten.threshold_backward.default(grad_output, 0)
    mm: f32[64, 784] = torch.ops.aten.mm.default(threshold_backward, view)
    view_backward: f32[32, 1, 784] = torch.ops.aten.view_backward.default(mm, x.shape)
    return view_backward
  • I can get a similar result with the DeepSeek-generated parser below, but I'm not sure it is robust for complex FX IR, so a direct API would be better.

https://chat.deepseek.com/a/chat/s/ff453434-bd94-4e30-9806-db571ed9b074

import re

def _split_call_args(line, open_idx):
    """Split the argument list of the call whose ``(`` is at ``line[open_idx]``.

    Commas separate arguments only at the call's own nesting level, so nested
    tuples/lists/dicts such as ``(32, 784)`` or ``[1, 0]`` and quoted strings
    stay intact. Returns the stripped argument strings in positional order,
    or ``None`` when the parentheses do not balance on this line (e.g. a call
    that continues onto the next line).
    """
    args = []
    current = []
    depth = 0
    quote = None  # active string delimiter, or None outside a string literal
    escaped = False
    for char in line[open_idx:]:
        if quote is not None:
            current.append(char)
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
            continue
        if char in "([{":
            depth += 1
            if depth == 1:
                continue  # the call's own opening parenthesis is not part of any arg
        elif char in ")]}":
            depth -= 1
            if depth == 0:
                # Matching close of the call: flush the final argument and
                # ignore anything after it (e.g. FX's ";  a = b = None").
                last = "".join(current).strip()
                if last:
                    args.append(last)
                return args
        elif char == "," and depth == 1:
            args.append("".join(current).strip())
            current = []
            continue
        elif char in "\"'":
            quote = char
        current.append(char)
    return None


def parse_fx_ir(ir_str):
    """Parse FX IR text into a list of node records.

    A recognised line has the form ``name[: annotation] = op(arg, ...)``, e.g.
    ``mm: f32[64, 784] = torch.ops.aten.mm.default(threshold_backward, view)``;
    the type annotation is optional. Blank lines, ``return`` statements and
    lines that are not a single call assignment (``def`` headers, comments,
    subscripts such as ``x = y[0]``, calls split across lines) are skipped.

    Args:
        ir_str: The IR as text, one statement per line.

    Returns:
        A list of dicts, one per parsed node, with keys ``"node_name"`` (the
        assigned variable), ``"operator"`` (the dotted callee), ``"inputs"``
        (argument source strings in positional order; nested tuples/lists are
        kept whole and keyword arguments keep their raw ``key=value`` text)
        and ``"output"`` (same as ``node_name``, since an FX node's result is
        bound to the node's own name).
    """
    # name, optional ": annotation", "=", dotted callee, opening parenthesis.
    header = re.compile(r"(\w+)\s*(?::\s*[^=]+?)?\s*=\s*([\w.]+)\s*\(")
    nodes = []
    for raw_line in ir_str.strip().splitlines():
        line = raw_line.strip()
        # Skip blank lines and return statements, but not nodes whose names
        # merely start with "return" (e.g. "returned").
        if not line or re.match(r"return\b", line):
            continue
        match = header.match(line)
        if match is None:
            continue
        # match.end() - 1 is the index of the call's opening parenthesis.
        inputs = _split_call_args(line, match.end() - 1)
        if inputs is None:
            continue
        node_name = match.group(1)
        nodes.append({
            "node_name": node_name,
            "operator": match.group(2),
            "inputs": inputs,
            "output": node_name,  # in FX IR the node name is the output variable
        })
    return nodes

# Example IR string to parse.
ir_str = """
threshold_backward: f32[32, 64] = torch.ops.aten.threshold_backward.default(grad_output, 0)
mm: f32[64, 784] = torch.ops.aten.mm.default(threshold_backward, view)
view_backward: f32[32, 1, 784] = torch.ops.aten.view_backward.default(mm, x.shape)
return view_backward
"""

# Parse the IR and print every node with its index.
# Printed labels: 节点 = node, 操作符 = operator, 输入 = inputs, 输出 = output.
parsed_nodes = parse_fx_ir(ir_str)
for item, node in enumerate(parsed_nodes):
    print(f"节点 {item} : {node['node_name']}")
    print(f"  操作符: {node['operator']}")
    print(f"  输入: {node['inputs']}")
    print(f"  输出: {node['output']}\n")

It seems m_export.module().graph.print_tabular() holds this information, but it works from a model, while my input is FX IR text. Is there any way to build the graph from the FX IR itself?

How to get the positional order of inputs and outputs for a graph exported via torch.export?