Transforming a model based on graph information

I am writing a graph transform that needs to modify a model based on graph information. My question is: how do I keep a mapping between module names in the model and node names in the graph?

In other words, how do I reference parameters in the parent module when I am in a child module in the model? As far as I can tell, there is no graph information until the model has been compiled.

For example, here is a snippet of ResNet-18:

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  ...
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    ...

The node layer2.0.downsample.0 gets its input from the output of layer1 (it receives the block's input, not the output of a sibling inside layer2.0), but there is no way to tell this from the module hierarchy until the graph has been traced. In my case, when I traverse the ResNet model and reach layer2.0.downsample.0, I want to know the out_channels of the node that feeds it.
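One way to get this producer information without losing module names is torch.fx, a symbolic tracer available in newer PyTorch releases and a different tool from torch.jit. A minimal sketch, assuming a release that provides torch.fx and Module.get_submodule: Net and Block are hypothetical miniatures of the ResNet structure above (the stem followed by layer2, with layer1 left out), and feeding_conv is a hypothetical helper. Each call_module node's target is the module's qualified name as it appears in the printout (e.g. "layer2.0.downsample.0"), so every graph node maps straight back to an original module.

```python
import torch.nn as nn
import torch.fx as fx


class Block(nn.Module):
    """Hypothetical miniature of a ResNet BasicBlock with a downsample branch."""

    def __init__(self, cin, cout):
        super().__init__()
        self.conv1 = nn.Conv2d(cin, cout, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(cout)
        self.relu = nn.ReLU()  # called twice in forward, like torchvision's BasicBlock
        self.downsample = nn.Sequential(
            nn.Conv2d(cin, cout, 1, stride=2, bias=False),
            nn.BatchNorm2d(cout),
        )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        return self.relu(out + self.downsample(x))


class Net(nn.Module):
    """Hypothetical stand-in for the ResNet stem followed by layer2."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.layer2 = nn.Sequential(Block(64, 128))

    def forward(self, x):
        return self.layer2(self.relu(self.bn1(self.conv1(x))))


def feeding_conv(model, gm, target):
    """Return the qualified name of the nearest upstream module with `out_channels`
    that feeds the module named `target`, or None if the walk reaches the input.
    Follows each node's first argument, assuming that is the data path."""
    node = next(n for n in gm.graph.nodes
                if n.op == "call_module" and n.target == target)
    prev = node.args[0]
    while prev.op != "placeholder":
        if prev.op == "call_module" and hasattr(model.get_submodule(prev.target),
                                                "out_channels"):
            return prev.target
        prev = prev.args[0]
    return None


model = Net()
gm = fx.symbolic_trace(model)  # leaf modules keep their qualified names as node targets
src = feeding_conv(model, gm, "layer2.0.downsample.0")
print(src, model.get_submodule(src).out_channels)  # conv1 64
```

Because Block.relu is called twice, the traced graph contains two nodes whose target is "layer2.0.relu": one module, two call sites. A module name identifies a module, not a unique position in the graph.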

Borrowing some ideas from the hiddenlayer project, I know I can use the JIT tracer to get the graph:

import torch
trace, out = torch.jit.get_trace_graph(model, torch.zeros(1, 3, 224, 224))  # dummy input
torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
torch_graph = trace.graph()

which can then be traversed in trace order. However, the module names are lost at this point. I could write my transform by building a lookup table while traversing the graph in trace order, and then applying the changes from the lookup table back to the model. But how would I keep references to the original modules while traversing the graph?
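The two-phase lookup-table idea can be sketched with torch.fx, assuming a PyTorch release that provides it. Phase 1 records, in graph order, which modules feed each module; phase 2 uses those dotted names with get_submodule to edit the original model in place. TwoBranch, build_producer_table, replace_module and set_out_channels are hypothetical names. The sketch assumes only Conv2d consumers need rebuilding and re-initializes the new layers instead of copying weights.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class TwoBranch(nn.Module):
    """Hypothetical model: `skip` reads the output of `stem`, not of its sibling `body`."""

    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 16, 3, padding=1)
        self.body = nn.Conv2d(16, 32, 3, padding=1)
        self.skip = nn.Conv2d(16, 32, 1)

    def forward(self, x):
        x = self.stem(x)
        return self.body(x) + self.skip(x)


def build_producer_table(model):
    """Phase 1: walk the traced graph in execution order and map each module name
    to the names of the modules whose outputs it consumes directly."""
    gm = fx.symbolic_trace(model)
    return {
        node.target: [a.target for a in node.all_input_nodes if a.op == "call_module"]
        for node in gm.graph.nodes
        if node.op == "call_module"
    }


def replace_module(model, qualname, new_module):
    """Swap the submodule at a dotted path such as "layer2.0.downsample.0"."""
    parent, _, attr = qualname.rpartition(".")
    setattr(model.get_submodule(parent), attr, new_module)


def set_out_channels(model, name, new_out, table):
    """Phase 2: widen Conv2d `name` and rebuild every Conv2d that consumes it.
    New layers are freshly initialized; no weights are copied."""
    old = model.get_submodule(name)
    replace_module(model, name, nn.Conv2d(old.in_channels, new_out, old.kernel_size,
                                          old.stride, old.padding,
                                          bias=old.bias is not None))
    for consumer, producers in table.items():
        c = model.get_submodule(consumer)
        if name in producers and isinstance(c, nn.Conv2d):
            replace_module(model, consumer, nn.Conv2d(new_out, c.out_channels,
                                                      c.kernel_size, c.stride, c.padding,
                                                      bias=c.bias is not None))


model = TwoBranch()
table = build_producer_table(model)  # {'stem': [], 'body': ['stem'], 'skip': ['stem']}
set_out_channels(model, "stem", 24, table)
out = model(torch.randn(1, 3, 8, 8))
print(out.shape)  # torch.Size([1, 32, 8, 8])
```

The graph is only read in phase 1, and every mutation in phase 2 goes through the original module tree by name, so no reference from a graph node to a module object has to be kept alive between the two phases.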

Any help is welcome, thanks!

A module doesn't know what its parent is, and that's a design decision rather than a missing feature. For one thing, you can use a single module in several places (several times in one model, or in several models), so there is no unique parent.

If you want to pass runtime information to modules, pass it in as arguments (or use a functional interface in the first place).
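A minimal sketch of that advice, with hypothetical class names: the child never looks up its parent; instead the parent hands its own parameter to the child's forward, and the functional form torch.nn.functional.conv2d takes a weight tensor as a plain argument. Since nothing in ScaledConv refers to a particular parent, it can be reused under any parent.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledConv(nn.Module):
    """Hypothetical child: receives the parent's parameter as a forward argument
    instead of reaching up to its parent."""

    def __init__(self, cin, cout):
        super().__init__()
        self.conv = nn.Conv2d(cin, cout, 3, padding=1, bias=False)

    def forward(self, x, scale):
        # scale has shape (cout,); reshape so it broadcasts over (N, C, H, W)
        return self.conv(x) * scale.view(1, -1, 1, 1)


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(8))  # state that lives in the parent
        self.child = ScaledConv(3, 8)
        self.mix_weight = nn.Parameter(torch.randn(8, 8, 1, 1))

    def forward(self, x):
        x = self.child(x, self.scale)  # pass parent state down explicitly
        # Functional interface: the weight is just another argument.
        return F.conv2d(x, self.mix_weight)


parent = Parent()
y = parent(torch.randn(2, 3, 16, 16))
print(y.shape)  # torch.Size([2, 8, 16, 16])
```

Gradients still reach parent.scale, because it enters the computation as an ordinary tensor argument.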

Best regards

Thomas