Hi,
According to the documentation here: torch.fx — PyTorch main documentation, the return value of `torch.fx.Tracer.trace()` and `torch.fx.symbolic_trace()` is a `Graph`. However, in the following code, the type of `fx_model` is `GraphModule`:
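```python
import torch
import torch.fx

fx_model = torch.fx.symbolic_trace(model)  # `model` is an existing nn.Module
modules = dict(fx_model.named_modules())
for node in fx_model.graph.nodes:
    # call_module nodes store the submodule's qualified name in node.target
    if node.op == "call_module" and node.target in modules:
        print(node.target)
        if isinstance(modules[node.target], torch.nn.Conv2d):
            print("conv node")
```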
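A minimal sketch of the two return types, assuming a recent PyTorch (`TinyNet` is a hypothetical model used only for illustration). `Tracer().trace()` returns a bare `torch.fx.Graph`, while `symbolic_trace()` runs the same tracing and then wraps that graph, together with the traced module, into a `GraphModule`; the `Graph` is still reachable as `fx_model.graph`.

```python
import torch
import torch.fx

class TinyNet(torch.nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

model = TinyNet()

# Tracer.trace() returns a plain Graph: a list of nodes, not a callable module
graph = torch.fx.Tracer().trace(model)
print(isinstance(graph, torch.fx.Graph))

# symbolic_trace() traces the same way, then wraps the Graph in a GraphModule
fx_model = torch.fx.symbolic_trace(model)
print(isinstance(fx_model, torch.fx.GraphModule))
print(isinstance(fx_model.graph, torch.fx.Graph))

# The same wrapping can be done by hand from the root module and the Graph
gm = torch.fx.GraphModule(model, graph)
x = torch.randn(1, 3, 8, 8)
print(torch.allclose(gm(x), model(x)))
```

Because `GraphModule` may create a per-instance subclass, `isinstance(fx_model, torch.fx.GraphModule)` is likely a safer check than comparing `type(fx_model)` directly.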
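A self-contained variant of the loop above, assuming PyTorch 1.9 or later for `nn.Module.get_submodule`. The model and the helper `find_conv_nodes` are hypothetical names for illustration. It looks up each `call_module` node's submodule directly instead of building a `named_modules()` dict first.

```python
import torch
import torch.fx

class TinyNet(torch.nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)
        self.conv2 = torch.nn.Conv2d(8, 8, kernel_size=3)

    def forward(self, x):
        x = torch.relu(self.bn(self.conv1(x)))
        return self.conv2(x)

def find_conv_nodes(gm: torch.fx.GraphModule):
    """Return the targets of call_module nodes that refer to an nn.Conv2d."""
    convs = []
    for node in gm.graph.nodes:
        # Only call_module nodes have a string target naming a submodule;
        # torch.relu above becomes a call_function node and is skipped
        if node.op == "call_module":
            submod = gm.get_submodule(node.target)
            if isinstance(submod, torch.nn.Conv2d):
                convs.append(node.target)
    return convs

fx_model = torch.fx.symbolic_trace(TinyNet())
print(find_conv_nodes(fx_model))  # ['conv1', 'conv2']
```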