After tracing an `nn.Module` with torch.fx's tracer, I can tell where operator nodes with opcode `call_module` sit in the module hierarchy by looking at their `target` attribute (the qualified name). For example, `SubMod.conv` tells me that the convolution layer `conv` is inside the `SubMod` module. However, I cannot do the same for nodes with other opcodes such as `call_method`, `call_function`, `placeholder` and `output`. I am particularly interested in the `call_method` and `call_function` nodes, and I am wondering whether there is any way to extract their location in the hierarchy of a nested module. When I look at the `target` attribute of a `call_method` or `call_function` node, I usually see something like `<built-in function ...>`, sometimes along with the hex address of the function or method object. It is hard to infer any structural information from that.
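A quick way to see this is to print `node.op` and `node.target` for every node in a traced graph. The sketch below uses a small hypothetical model (`Sub` and `Top` are illustrative names, not from the code further down). The `call_module` node's target is a dotted path string. The `call_function` node's target is the Python function object itself (`operator.add` for `+`), which carries no module path.

```python
import torch.nn as nn
import torch.fx as fx


class Sub(nn.Module):  # hypothetical submodule for illustration
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x) + 1  # conv -> call_module, "+" -> call_function


class Top(nn.Module):  # hypothetical top-level module
    def __init__(self):
        super().__init__()
        self.SubMod = Sub()

    def forward(self, x):
        return self.SubMod(x)


graph = fx.Tracer().trace(Top())
for node in graph.nodes:
    # call_module targets are qualified names; call_function targets are callables
    print(f"{node.op:<14} {node.target!r}")
```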
Locating `call_method` and `call_function` nodes in the hierarchy is particularly important for distinguishing models like the two below. `Model1` and `Model2` are structurally different in the PyTorch code: the `x + x` lives inside the submodule in `Model1` but in the top-level module in `Model2`. Yet after tracing, the printed graph/operator node tables for the two models are identical (please see the screenshot below for the printed output, and the code below to reproduce it). Is there a good way to obtain this information? I would really appreciate any suggestions, thank you very much in advance!
```python
import torch
import torch.nn as nn
import torch.fx as fx


class SubModule1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, (3, 3))

    def forward(self, x):
        x = self.conv(x)
        x = x + x  # the add happens inside the submodule
        return x


class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.SubMod = SubModule1()

    def forward(self, x):
        x = self.SubMod(x)
        return x


class SubModule2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, (3, 3))

    def forward(self, x):
        x = self.conv(x)
        return x


class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.SubMod = SubModule2()

    def forward(self, x):
        x = self.SubMod(x)
        x = x + x  # the add happens in the top-level module
        return x


def print_graph_table(module):
    tracer_class = fx.Tracer
    graph = tracer_class().trace(module)
    graph.print_tabular()  # requires the `tabulate` package


model1 = Model1()
print("model1 Graph Table:")
print_graph_table(model1)
print()
model2 = Model2()
print("model2 Graph Table:")
print_graph_table(model2)
```
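One approach worth trying is sketched below. It is not an official torch.fx feature. The idea is to subclass `fx.Tracer`, keep a stack of submodule paths while tracing, and stamp the innermost path onto every node via `node.meta`. `ScopeTracer`, the `"module_path"` meta key and the `Inner`/`Outer` toy models are hypothetical names. The sketch assumes that `Tracer.call_module` is invoked for every submodule call during tracing and that all nodes are created through `Tracer.create_node`. Newer PyTorch releases may also populate `node.meta["nn_module_stack"]` on their own, but this sketch does not rely on that.

```python
import torch.nn as nn
import torch.fx as fx


class ScopeTracer(fx.Tracer):
    """Hypothetical tracer that tags each node with the submodule path it was created in."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._scope = []  # stack of qualified submodule names, innermost last

    def call_module(self, m, forward, args, kwargs):
        # Push the submodule's qualified name (e.g. "SubMod") while it is traced.
        self._scope.append(self.path_of_module(m))
        try:
            return super().call_module(m, forward, args, kwargs)
        finally:
            self._scope.pop()

    def create_node(self, *args, **kwargs):
        node = super().create_node(*args, **kwargs)
        # "" means the node belongs to the root module itself.
        node.meta["module_path"] = self._scope[-1] if self._scope else ""
        return node


class Inner(nn.Module):  # hypothetical toy submodule
    def forward(self, x):
        return x.relu() + 1  # call_method (relu) and call_function (add)


class Outer(nn.Module):  # hypothetical toy top-level module
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x) * 2  # call_function (mul) at the root


graph = ScopeTracer().trace(Outer())
for node in graph.nodes:
    print(f"{node.op:<14} {node.name:<8} path={node.meta['module_path']!r}")
```

With the same tracer applied to `Model1` and `Model2` above, the `add` node should be tagged `"SubMod"` for `Model1` and `""` (the root) for `Model2`. That is the distinction the plain `print_tabular` output loses. Storing the path in `node.meta` rather than renaming nodes leaves the generated code of the `GraphModule` unchanged.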