- For example, given an FX IR in string form, is there an API that directly returns the input variables of each operation (node) in the IR?
def backward(self, grad_output):
threshold_backward: f32[32, 64] = torch.ops.aten.threshold_backward.default(grad_output, 0)
mm: f32[64, 784] = torch.ops.aten.mm.default(threshold_backward, view)
view_backward: f32[32, 1, 784] = torch.ops.aten.view_backward.default(mm, x.shape)
return view_backward
- A result similar to the DeepSeek-generated parser below would be fine, but I'm not sure it is robust for complex FX IR, so a direct API would be better.
https://chat.deepseek.com/a/chat/s/ff453434-bd94-4e30-9806-db571ed9b074
import re
# Matches the head of a node line: "name[: annotation] = dotted.operator(".
# The argument list is NOT captured here because a regex cannot balance
# nested brackets; it is scanned separately by _split_call_args.
_FX_NODE_HEAD = re.compile(r"(\w+)\s*(?::[^=]*)?=\s*([\w.]+)\(")


def _split_call_args(text, start):
    """Split a call's argument list into top-level argument strings.

    ``start`` is the index just past the call's opening "(" in ``text``.
    Commas nested inside (), [], {} or inside quoted strings do not split.
    Returns the list of stripped arguments, or None when the closing ")"
    is never found on this line.
    """
    args = []
    current = []
    depth = 1  # the call's own opening parenthesis is already consumed
    quote = None
    i = start
    n = len(text)
    while i < n:
        ch = text[i]
        if quote is not None:
            current.append(ch)
            if ch == "\\" and i + 1 < n:
                # Keep the escaped character so an escaped quote does not
                # terminate the string literal.
                current.append(text[i + 1])
                i += 1
            elif ch == quote:
                quote = None
        elif ch in "\"'":
            quote = ch
            current.append(ch)
        elif ch in "([{":
            depth += 1
            current.append(ch)
        elif ch in ")]}":
            depth -= 1
            if depth == 0:
                # Closing parenthesis of the call: flush the last argument.
                last = "".join(current).strip()
                if last:
                    args.append(last)
                return args
            current.append(ch)
        elif ch == "," and depth == 1:
            args.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
        i += 1
    return None


def parse_fx_ir(ir_str):
    """Extract the operator and input expressions of each node in FX IR text.

    Every line of the form ``name[: annotation] = dotted.op(arg, ...)`` yields
    one dict with the keys ``"node_name"``, ``"operator"``, ``"inputs"`` (a
    list of argument source strings, keyword arguments included verbatim) and
    ``"output"`` (same as the node name). Lines that do not have this shape
    (``def`` headers, ``return`` statements, blank lines, calls whose
    parentheses do not close on the same line) are skipped.

    NOTE: this is text parsing. When the ``torch.fx.Graph`` object itself is
    available, iterating ``graph.nodes`` and reading ``node.args`` /
    ``node.all_input_nodes`` avoids parsing altogether.

    Args:
        ir_str: The printed FX IR, one node per line.

    Returns:
        A list of node dicts in source order.
    """
    nodes = []
    for raw_line in ir_str.strip().splitlines():
        line = raw_line.strip()
        if not line:
            continue
        # Skip return statements, but not nodes whose name merely starts
        # with "return" (e.g. "returns_1: ... = op(x)").
        if line.split(maxsplit=1)[0] == "return":
            continue
        head = _FX_NODE_HEAD.match(line)
        if head is None:
            continue
        inputs = _split_call_args(line, head.end())
        if inputs is None:
            continue
        node_name = head.group(1)
        nodes.append({
            "node_name": node_name,
            "operator": head.group(2),
            "inputs": inputs,
            "output": node_name,  # in FX IR the node name is the output variable
        })
    return nodes
# Example IR string
ir_str = """
threshold_backward: f32[32, 64] = torch.ops.aten.threshold_backward.default(grad_output, 0)
mm: f32[64, 784] = torch.ops.aten.mm.default(threshold_backward, view)
view_backward: f32[32, 1, 784] = torch.ops.aten.view_backward.default(mm, x.shape)
return view_backward
"""
# Parse the example and print one summary per node; enumerate supplies the
# running index instead of a hand-maintained counter.
parsed_nodes = parse_fx_ir(ir_str)
for item, node in enumerate(parsed_nodes):
    print(f"节点 {item} : {node['node_name']}")
    print(f" 操作符: {node['operator']}")
    print(f" 输入: {node['inputs']}")
    print(f" 输出: {node['output']}\n")