I’m trying to traverse a graph that I got from tracing the model. I have most of the data I need, except for one thing: the shape of each layer’s output.
import torch
import torchvision
# VGG16
model = torchvision.models.vgg16()
# Get trace graph
trace, out = torch.jit.get_trace_graph(model, torch.zeros([1, 3, 224, 224]))
torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
graph = trace.graph()
# Print first node in the graph
node = list(graph.nodes())[0]
print(node)
Still searching for a solution. As a temporary hack, I used a regular expression to extract the shape from the string, but there must be a better way to do this. At least I hope there is.
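For reference, the regular-expression workaround could look roughly like the sketch below. It assumes the printed node contains a type annotation of the form `Float(1, 64, 224, 224)` right after a colon; that format, the sample string in the comment, and the helper name `shape_from_node_string` are all assumptions for illustration, and the exact output of `print(node)` may differ between PyTorch versions.

```python
import re

# Assumed textual form of a node: "%name : Float(d0, d1, ...) = op(...)".
# This pattern is a guess at that format, not something guaranteed by PyTorch.
_SHAPE_RE = re.compile(r":\s*[A-Za-z]+\(([\d\s,]+)\)")


def shape_from_node_string(text):
    """Return the first shape found in a node's string form, or None."""
    match = _SHAPE_RE.search(text)
    if match is None:
        return None
    # Split "1, 64, 224, 224" into integers; int() tolerates the spaces.
    return [int(dim) for dim in match.group(1).split(",") if dim.strip()]


# Hypothetical usage: shape_from_node_string(str(node))
```

Parsing the string is fragile because it depends on how the node happens to be printed, which is why reading the shape from the node object itself is preferable when possible.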
Hi, I’ve found how to get the output shape of a torch._C.Node.
Let node be a torch._C.Node. The output shape of this node is node.output().type().sizes().
Hope this helps.
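Applied to every node in the graph, the suggestion above could be wrapped as in the sketch below. The helper names `node_output_shapes` and `graph_output_shapes` are made up for this example. It assumes each node exposes `kind()` and `outputs()` and that a tensor-typed output has a `type()` with a `sizes()` method, as described above. Outputs that are not tensors, or whose shape was not recorded, are reported as `None` rather than raising.

```python
def node_output_shapes(node):
    """Return one entry per output of `node`: a list of ints, or None."""
    shapes = []
    for value in node.outputs():
        try:
            sizes = value.type().sizes()
        except (AttributeError, RuntimeError):
            # Assumption: non-tensor types have no sizes(), and some
            # versions raise when the shape is not known.
            sizes = None
        shapes.append(list(sizes) if sizes is not None else None)
    return shapes


def graph_output_shapes(graph):
    """Return (node kind, output shapes) pairs for every node in `graph`."""
    return [(node.kind(), node_output_shapes(node)) for node in graph.nodes()]


# Hypothetical usage with the graph from the first post:
#   for kind, shapes in graph_output_shapes(graph):
#       print(kind, shapes)
```

Iterating over `outputs()` instead of calling `output()` keeps the helper usable for nodes with more than one output.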
I was trying to replicate the examples in here, but my graphs were not showing any tensor shapes. I have then implemented the change cjsoft suggested by substituting the default