I am writing a graph transform that needs to modify a model based on graph information. My question is: how do I keep references between module names in the model and node names in the graph?
In other words, how do I reference parameters in the parent module when I am in a child module in the model? As far as I can tell, there is no graph information until the model has been compiled.
For example, let’s look at a snippet from ResNet18:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  ...
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    ...
The node layer2.0.downsample.0 gets its input from the node bn1, but there is no way to tell this until the graph is compiled. In my case, when I traverse the ResNet model and get to layer2.0.downsample.0, I want to know the number of out_channels from the parent node.
Stealing some ideas from the hiddenlayer project, I know I can use the JIT compiler to get the graph:
# args is a sample input for the model (assumed, e.g. a 1x3x224x224 tensor)
trace, out = torch.jit.get_trace_graph(model, args)
torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
torch_graph = trace.graph()
which can then be traversed in trace order. However, all the module names are lost at this point. I could write my transform by building a lookup table while traversing the graph in trace order, and then applying the changes from the lookup table back to the model. But how would I keep references to the original modules while traversing the graph?
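To make the lookup-table idea concrete, here is a rough sketch of what I have in mind. It assumes each trace node can give me a scope string of the form "ResNet/Sequential[layer2]/BasicBlock[0]/Sequential[downsample]/Conv2d[0]". That format is an assumption on my part and may differ between PyTorch versions, and the helper names scope_to_module_name and build_lookup are made up for illustration.

```python
import re

# Assumed scope format (hypothetical, may vary between PyTorch versions):
#   "ResNet/Sequential[layer2]/BasicBlock[0]/Sequential[downsample]/Conv2d[0]"
# Each "Type[attr]" segment carries the attribute name inside the brackets.
_SCOPE_PART = re.compile(r"[^/\[\]]+\[([^\[\]]+)\]")


def scope_to_module_name(scope):
    """Turn a trace scope string into a dotted module name."""
    # The root segment ("ResNet") has no brackets, so it is skipped.
    return ".".join(_SCOPE_PART.findall(scope))


def build_lookup(scopes):
    """Map each dotted module name to the trace positions where it appears."""
    table = {}
    for index, scope in enumerate(scopes):
        name = scope_to_module_name(scope)
        if name:  # ignore nodes that only belong to the root module
            table.setdefault(name, []).append(index)
    return table


if __name__ == "__main__":
    scopes = [
        "ResNet/Conv2d[conv1]",
        "ResNet/BatchNorm2d[bn1]",
        "ResNet/Sequential[layer2]/BasicBlock[0]/Sequential[downsample]/Conv2d[0]",
    ]
    print(build_lookup(scopes))
```

If the dotted names come out right, they should match the keys of dict(model.named_modules()), so the table could be used to get back to the original module for each graph node.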
Any help is welcome, thanks!