ONNX Hierarchical Trace

I have a simple model that consists of two parts:

import torch
import torch.nn as nn

class Part1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.conv2 = nn.Conv2d(16,32,3)
        self.maxpool1 = nn.MaxPool2d(32,2)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool1(x)
        return x


class Part2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(32,64,3)
        self.conv2 = nn.Conv2d(64,128,3)
        self.maxpool1 = nn.MaxPool2d(64,2)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool1(x)
        return x


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.part1 = Part1()
        self.part2 = Part2()
    def forward(self, x):
        x = self.part1(x)
        x = self.part2(x)
        return x

If I trace the Wrapper model with the ONNX exporter:

dummy_input = torch.randn(1,3,224,224)
model = Wrapper()
torch.onnx.export(model, dummy_input, "wrapper.onnx", verbose=True, opset_version=11)

I get the following output:

graph(%input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu),
      %part1.conv1.weight : Float(16:27, 3:9, 3:3, 3:1, requires_grad=1, device=cpu),
      %part1.conv1.bias : Float(16:1, requires_grad=1, device=cpu),
      %part1.conv2.weight : Float(32:144, 16:9, 3:3, 3:1, requires_grad=1, device=cpu),
      %part1.conv2.bias : Float(32:1, requires_grad=1, device=cpu),
      %part2.conv1.weight : Float(64:288, 32:9, 3:3, 3:1, requires_grad=1, device=cpu),
      %part2.conv1.bias : Float(64:1, requires_grad=1, device=cpu),
      %part2.conv2.weight : Float(128:576, 64:9, 3:3, 3:1, requires_grad=1, device=cpu),
      %part2.conv2.bias : Float(128:1, requires_grad=1, device=cpu)):
  %9 : Float(1:788544, 16:49284, 222:222, 222:1, requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%input.1, %part1.conv1.weight, %part1.conv1.bias)
  %10 : Float(1:1548800, 32:48400, 220:220, 220:1, requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%9, %part1.conv2.weight, %part1.conv2.bias) 
  %11 : Float(1:288800, 32:9025, 95:95, 95:1, requires_grad=1, device=cpu) = onnx::MaxPool[ceil_mode=0, kernel_shape=[32, 32], pads=[0, 0, 0, 0], strides=[2, 2]](%10)
  %12 : Float(1:553536, 64:8649, 93:93, 93:1, requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%11, %part2.conv1.weight, %part2.conv1.bias)
  %13 : Float(1:1059968, 128:8281, 91:91, 91:1, requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%12, %part2.conv2.weight, %part2.conv2.bias)
  %14 : Float(1:25088, 128:196, 14:14, 14:1, requires_grad=1, device=cpu) = onnx::MaxPool[ceil_mode=0, kernel_shape=[64, 64], pads=[0, 0, 0, 0], strides=[2, 2]](%13)
  return (%14)

which is fine.
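
Loading the saved file shows the same flat list of nodes. A small sketch with the onnx package (assuming it is installed), just to inspect what was exported:

import onnx

onnx_model = onnx.load("wrapper.onnx")
for node in onnx_model.graph.node:
    # Print each exported node's type, name, inputs and outputs.
    print(node.op_type, node.name, list(node.input), list(node.output))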

But as you can see, the model is hierarchical: it is built from two parts. This hierarchy is flattened in the trace, however, and there is no way to tell which part a given layer belongs to.

I could infer which part each conv layer belongs to from its parameter names (e.g. part1.conv1.weight), but this is not guaranteed, and for parameter-free layers such as the first pooling layer the trace gives no indication at all.

Is there a way to keep the hierarchy while tracing or at least indicate it?

In other words, is there a way to rename %9, %10, etc. to something that indicates which part each layer belongs to (and is it even possible to change them at all)?
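
For reference, the hierarchy I am referring to is the module tree that PyTorch itself reports. A quick illustration, using the Wrapper model defined above (the first entry is the root module with an empty name):

model = Wrapper()
for name, module in model.named_modules():
    # e.g. "part1 -> Part1", "part1.conv1 -> Conv2d", "part2.maxpool1 -> MaxPool2d"
    print(name, "->", type(module).__name__)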

It looks like the scope is no longer available since PyTorch 1.4.
There is a workaround to get the scope, mentioned here:

While iterating over the graph nodes, the scope name can be obtained with node.scopeName().
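
A minimal sketch of that approach, assuming the Wrapper model defined above; whether scopeName() is populated, and the exact format of the string it returns, depends on the PyTorch version:

import torch

model = Wrapper()
dummy_input = torch.randn(1, 3, 224, 224)

# Trace with TorchScript; inlined_graph flattens the part1/part2 submodule calls,
# so the individual conv and pooling ops appear as nodes of a single graph.
traced = torch.jit.trace(model, dummy_input)

for node in traced.inlined_graph.nodes():
    # node.kind() is the operator (e.g. "aten::conv2d", "aten::max_pool2d");
    # node.scopeName() reports the module scope the node was recorded under.
    print(node.kind(), node.scopeName())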