I want to use symbolic_trace from torch.fx to trace the model.
symbolic_trace returns a GraphModule. How can I get the input shape and output shape of each node in the GraphModule's graph?
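The graph module's graph has a list of nodes, and you can find this information in each node's meta attribute. Note that symbolic_trace on its own does not record shapes, because it traces with Proxy objects rather than real tensors. You first need to run the ShapeProp pass (torch.fx.passes.shape_prop.ShapeProp) with an example input, e.g. ShapeProp(gm).propagate(x). After that:
for output_shape, use list(gm.graph.nodes)[i].meta['tensor_meta'].shape
for input_shape, grab the args of the node and use list(gm.graph.nodes)[i].args[j].meta['tensor_meta'].shape (this only works for args that are themselves Nodes; constant args such as ints have no meta)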
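A minimal sketch of the whole flow, assuming a recent PyTorch where ShapeProp lives in torch.fx.passes.shape_prop. SmallNet and node_shapes are hypothetical names made up for this illustration, and the (2, 8) example input is arbitrary. It uses node.all_input_nodes instead of indexing args[j], so non-Node arguments (ints, slices, etc.) are skipped automatically:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp, TensorMetadata


class SmallNet(nn.Module):
    """Hypothetical example model, only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


def _shape(meta):
    # tensor_meta is a TensorMetadata for single-tensor results; for tuple
    # outputs it is a tuple of them, which this simple helper ignores.
    return tuple(meta.shape) if isinstance(meta, TensorMetadata) else None


def node_shapes(gm, *example_inputs):
    """Hypothetical helper: map node name -> (input shapes, output shape)."""
    # ShapeProp executes the GraphModule on real inputs and stores
    # node.meta["tensor_meta"] for every node whose result contains a tensor.
    ShapeProp(gm).propagate(*example_inputs)

    shapes = {}
    for node in gm.graph.nodes:
        out_shape = _shape(node.meta.get("tensor_meta"))
        # all_input_nodes holds only the Node arguments of this node.
        in_shapes = [
            _shape(arg.meta.get("tensor_meta")) for arg in node.all_input_nodes
        ]
        shapes[node.name] = (in_shapes, out_shape)
    return shapes


if __name__ == "__main__":
    gm = symbolic_trace(SmallNet())
    for name, (ins, out) in node_shapes(gm, torch.randn(2, 8)).items():
        print(f"{name}: inputs={ins} output={out}")
```

Because ShapeProp actually runs the model, the recorded shapes are only valid for the example input you pass in; a different batch size, for instance, gives different shapes.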
Thanks, bro. I will try it.