Node output shape from trace graph

I’m trying to traverse the graph I got from tracing a model. I have most of the data I need except for one thing: the output shape of each layer.

import torch
import torchvision

# VGG16
model = torchvision.models.vgg16()
# Get the trace graph (get_trace_graph and _optimize_trace come from older PyTorch releases)
trace, out = torch.jit.get_trace_graph(model, torch.zeros([1, 3, 224, 224]))
torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
graph = trace.graph()
# Print the first node in the graph
node = list(graph.nodes())[0]
print(node)

Output:

%33 : Float(1, 64, 224, 224) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%0, %1, %2), scope: VGG/Sequential[features]/Conv2d[0]

See the Float(1, 64, 224, 224) in the output? That’s the value I want. Which method or property returns that?

Bump!

Still searching for a solution. As a temporary hack, I used a regular expression to extract the shape from the string, but there must be a better way to do this. At least I hope there is.
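For reference, a regex hack along those lines might look like the sketch below. It is only an illustration under assumptions: parse_shape is a hypothetical name, and it assumes the shape is the first "Type(d0, d1, ...)" annotation in the printed string, as in the output above. It accepts any capitalized type name (Float, Long, ...) and stops at the first non-integer field, since newer PyTorch versions may print extra fields such as strides after the sizes.

```python
import re

# Assumed format: a capitalized type name followed by "(" and comma-separated
# integer sizes, e.g. "Float(1, 64, 224, 224)". Only the leading integers are captured.
_SHAPE_RE = re.compile(r"\b[A-Z]\w*\(((?:\d+\s*,\s*)*\d+)")


def parse_shape(text):
    """Extract the first tensor shape printed in a node or value string, or None."""
    match = _SHAPE_RE.search(text)
    if match is None:
        return None  # no concrete integer sizes printed (e.g. int outputs)
    return tuple(int(dim) for dim in match.group(1).split(","))
```

For the node printed above, parse_shape(str(node)) should give (1, 64, 224, 224). Anything not printed as concrete integers yields None rather than a guess.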

Have you found a better way? I need to get the output shape of a torch._C.Node as well.

Hi, I’ve found how to get the output shape of a torch._C.Node.
If node is a torch._C.Node, its output shape is node.output().type().sizes().
Hope this helps.
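Building on that, walking the whole graph and collecting every node's output shapes might look like the sketch below. It assumes the torch._C calls used in this thread (graph.nodes(), node.kind(), node.outputs() and value.type().sizes()); collect_output_shapes is a hypothetical helper name. It iterates node.outputs() rather than calling node.output(), which expects exactly one output, and records None wherever sizes cannot be read.

```python
def collect_output_shapes(graph):
    """Return [(node kind, [output shapes])] for every node in a traced graph.

    Hypothetical helper; assumes graph.nodes(), node.kind(), node.outputs()
    and value.type().sizes() behave as described in this thread.
    A shape is None when it cannot be read.
    """
    result = []
    for node in graph.nodes():
        shapes = []
        # Iterate every output instead of node.output(), which expects exactly one.
        for value in node.outputs():
            try:
                sizes = value.type().sizes()
            except (RuntimeError, AttributeError):
                # Assumption: outputs without readable sizes may raise here,
                # depending on the PyTorch version, instead of returning None.
                sizes = None
            shapes.append(tuple(sizes) if sizes is not None else None)
        result.append((node.kind(), shapes))
    return result
```

With the trace from the question, collect_output_shapes(trace.graph()) should list the onnx::Conv node with [(1, 64, 224, 224)] first.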

I was trying to replicate the examples here, but my graphs were not showing any tensor shapes. I then implemented the change cjsoft suggested, substituting the default

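import re

def get_shape(torch_node):
    # Parse the shape out of the printed output value, e.g. "Float(1, 64, 224, 224)"
    m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs())))
    if m:
        shape = m.group(1)
        shape = shape.split(",")
        shape = tuple(map(int, shape))
    else:
        shape = None
    return shape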

in get_shape within pytorch_builder.py with

def get_shape(torch_node):
    # Read the sizes directly from the output value's type
    shape = torch_node.output().type().sizes()
    return shape

which solved the issue. However, when I tried to plot my own network (a ResNet U-Net), it raised this error:

RuntimeError: r INTERNAL ASSERT FAILED at "..\\aten\\src\\ATen/core/jit_type.h":171, please report a bug to PyTorch.

So I wrapped the call in a try/except, and it seems to work fine:

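def get_shape(torch_node):
    try:
        shape = torch_node.output().type().sizes()
    except RuntimeError:
        # The internal assert above surfaces as a RuntimeError; catch only that
        # instead of using a bare except, which would also hide unrelated bugs.
        shape = None
    return shape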

Hope it helps.
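A possible refinement, sketched under the same assumptions: catch only the errors you expect, and fall back to the regex approach from earlier in the thread before giving up. This get_shape only mirrors the name of the function in pytorch_builder.py; its body is hypothetical, and the regex assumes the Float(1, 64, 224, 224)-style annotation shown at the top of the thread.

```python
import re

# Assumed annotation format: a type name such as Float followed by integer sizes,
# e.g. "Float(1, 64, 224, 224)"; trailing fields like strides are ignored.
_SHAPE_RE = re.compile(r"\b[A-Z]\w*\(((?:\d+\s*,\s*)*\d+)")


def get_shape(torch_node):
    """Return the first output's shape as a tuple, or None if it is unavailable."""
    value = next(iter(torch_node.outputs()), None)  # first output, like the original
    if value is None:
        return None
    try:
        sizes = value.type().sizes()
        if sizes is not None:  # some versions may return None instead of raising
            return tuple(sizes)
    except (RuntimeError, AttributeError):
        pass  # e.g. the internal assert reported above
    # Fallback: parse the printed value, as the regex version of get_shape does.
    match = _SHAPE_RE.search(str(value))
    if match is None:
        return None
    return tuple(int(dim) for dim in match.group(1).split(","))
```

None is still the last resort here, so the caller has to be ready to draw or skip a node without a shape.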