I have a simple PyTorch model that consists of two parts:
import torch
import torch.nn as nn
class Part1(nn.Module):
    """First stage: two 3x3 convolutions (3 -> 16 -> 32 channels), then max-pooling.

    No activation functions are applied between the layers.
    """

    def __init__(self):
        super().__init__()
        # The attribute names below become the parameter names seen in
        # state_dict and in the exported ONNX initializers.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        # NOTE(review): MaxPool2d's first positional argument is kernel_size,
        # so this is a 32x32 pooling window with stride 2. The value matches
        # conv2's out_channels, which suggests a channel count may have been
        # intended here -- confirm before relying on it.
        self.maxpool1 = nn.MaxPool2d(kernel_size=32, stride=2)

    def forward(self, x):
        """Apply conv1, then conv2, then maxpool1 to ``x`` and return the result."""
        return self.maxpool1(self.conv2(self.conv1(x)))
class Part2(nn.Module):
    """Second stage: two 3x3 convolutions (32 -> 64 -> 128 channels), then max-pooling.

    No activation functions are applied between the layers.
    """

    def __init__(self):
        super().__init__()
        # The attribute names below become the parameter names seen in
        # state_dict and in the exported ONNX initializers.
        self.conv1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        # NOTE(review): MaxPool2d's first positional argument is kernel_size,
        # so this is a 64x64 pooling window with stride 2. The value matches
        # conv1's out_channels, which suggests a channel count may have been
        # intended here -- confirm before relying on it.
        self.maxpool1 = nn.MaxPool2d(kernel_size=64, stride=2)

    def forward(self, x):
        """Apply conv1, then conv2, then maxpool1 to ``x`` and return the result."""
        return self.maxpool1(self.conv2(self.conv1(x)))
class Wrapper(nn.Module):
    """Top-level model that chains Part1 and Part2.

    The submodule attribute names ("part1", "part2") prefix the parameter
    names of each part, e.g. "part1.conv1.weight".
    """

    def __init__(self):
        super().__init__()
        self.part1 = Part1()
        self.part2 = Part2()

    def forward(self, x):
        """Run ``x`` through part1 and feed the result into part2."""
        return self.part2(self.part1(x))
If I export the Wrapper model to ONNX (which traces it):
# Example input used to trace the model: a batch of one 3-channel 224x224 tensor.
dummy_input = torch.randn(1, 3, 224, 224)
model = Wrapper()
# Trace `model` on `dummy_input` and write the resulting graph to wrapper.onnx;
# verbose=True also prints the traced graph.
torch.onnx.export(
    model,
    dummy_input,
    "wrapper.onnx",
    verbose=True,
    opset_version=11,
)
I get the following output:
graph(%input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu),
%part1.conv1.weight : Float(16:27, 3:9, 3:3, 3:1, requires_grad=1, device=cpu),
%part1.conv1.bias : Float(16:1, requires_grad=1, device=cpu),
%part1.conv2.weight : Float(32:144, 16:9, 3:3, 3:1, requires_grad=1, device=cpu),
%part1.conv2.bias : Float(32:1, requires_grad=1, device=cpu),
%part2.conv1.weight : Float(64:288, 32:9, 3:3, 3:1, requires_grad=1, device=cpu),
%part2.conv1.bias : Float(64:1, requires_grad=1, device=cpu),
%part2.conv2.weight : Float(128:576, 64:9, 3:3, 3:1, requires_grad=1, device=cpu),
%part2.conv2.bias : Float(128:1, requires_grad=1, device=cpu)):
%9 : Float(1:788544, 16:49284, 222:222, 222:1, requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%input.1, %part1.conv1.weight, %part1.conv1.bias)
%10 : Float(1:1548800, 32:48400, 220:220, 220:1, requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%9, %part1.conv2.weight, %part1.conv2.bias)
%11 : Float(1:288800, 32:9025, 95:95, 95:1, requires_grad=1, device=cpu) = onnx::MaxPool[ceil_mode=0, kernel_shape=[32, 32], pads=[0, 0, 0, 0], strides=[2, 2]](%10)
%12 : Float(1:553536, 64:8649, 93:93, 93:1, requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%11, %part2.conv1.weight, %part2.conv1.bias)
%13 : Float(1:1059968, 128:8281, 91:91, 91:1, requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%12, %part2.conv2.weight, %part2.conv2.bias)
%14 : Float(1:25088, 128:196, 14:14, 14:1, requires_grad=1, device=cpu) = onnx::MaxPool[ceil_mode=0, kernel_shape=[64, 64], pads=[0, 0, 0, 0], strides=[2, 2]](%13)
return (%14)
which is fine.
As you can see, the original model is hierarchical: it is composed of two parts. In the trace, however, this hierarchy is flattened, so there is no way to tell which part a given layer belongs to.
I can usually infer which part each conv layer belongs to from its parameter names (e.g. part1.conv1.weight), but this is not guaranteed. Moreover, layers that have no parameters, such as the first pooling layer (%11), carry no such names, so I can't tell from the trace which part they belong to.
Is there a way to preserve the module hierarchy while tracing/exporting, or at least to record which module each node comes from?