Extract layer parameters

Hi folks, I’m building a proof of concept that detects large convolutions and breaks them into smaller ones for devices with little on-device memory.

As a first step, I just want to detect where a large convolution may occur.

To do so, I’m using torch.fx to extract in_channels, out_channels and stride. If they’re larger than a certain threshold, I know I have a big convolution.
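A minimal sketch of the threshold test itself, assuming the module in hand is an nn.Conv2d. The helper name `is_large_conv` and the threshold values are hypothetical placeholders, not part of torch.fx:

```python
import torch.nn as nn

# Hypothetical thresholds; tune them to the target device's memory budget.
MAX_IN_CHANNELS = 64
MAX_OUT_CHANNELS = 64
MAX_STRIDE = 1

def is_large_conv(conv: nn.Conv2d,
                  max_in: int = MAX_IN_CHANNELS,
                  max_out: int = MAX_OUT_CHANNELS,
                  max_stride: int = MAX_STRIDE) -> bool:
    """Return True if any of the conv's sizes exceeds its threshold."""
    # Conv2d stores stride as a tuple such as (2, 2), so compare its largest entry.
    return (conv.in_channels > max_in
            or conv.out_channels > max_out
            or max(conv.stride) > max_stride)

print(is_large_conv(nn.Conv2d(32, 64, kernel_size=1)))
print(is_large_conv(nn.Conv2d(128, 32, kernel_size=1, stride=2)))
```

Keeping the check as a plain function on the module means it works the same whether the module comes from the original model or from a traced GraphModule.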

Unfortunately, I can’t find where the parameters of a convolution operation are stored on the node: https://github.com/…/pytorch/blob/master/torch/fx/node.py

I know the trace is aware of the parameters, because they show up when I just run symbolic_trace(model).

Any suggestions? Should I be working with the code object instead of the graph object?
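To see what a Node actually carries, this small sketch prints each node's `op` and `target` (the two-layer `TinyNet` is a made-up stand-in for the real model). For a `call_module` node, the target is only the module's name as a string. `gm.code` is likewise just the generated Python source as text, so the graph is usually the more convenient thing to walk:

```python
import torch.nn as nn
from torch.fx import symbolic_trace

class TinyNet(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(32, 64, kernel_size=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=1)

    def forward(self, x):
        return self.conv2(self.conv1(x))

gm = symbolic_trace(TinyNet())

# Each node records its kind of operation and a target; the Conv2d objects
# themselves (with in_channels/out_channels) are not stored on the node.
for node in gm.graph.nodes:
    print(node.op, repr(node.target))

# gm.code is generated source text, not a structured object.
print(type(gm.code).__name__)
```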

More specifically:

```python
import torch.nn as nn
from torch.fx import symbolic_trace

class SplittableCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1, stride=2)

    def forward(self, x):
        output = self.conv3(self.conv2(self.conv1(x)))
        return output

model = SplittableCNN()
symbolic_trace(model)
```

This outputs:

```
SplittableCNN(
  (conv1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(128, 32, kernel_size=(1, 1), stride=(2, 2))
)
```

I want to extract in_channels and out_channels and change them at a later stage.
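Since the end goal is to break a large convolution into smaller ones, here is one possible sketch (not the only way, and not anything torch.fx provides): split along the output channels by copying slices of the original weight into `num_chunks` smaller convolutions, then concatenate their outputs. `OutChannelSplitConv2d` is a hypothetical name, and the sketch assumes `groups == 1`:

```python
import torch
import torch.nn as nn

class OutChannelSplitConv2d(nn.Module):
    """Hypothetical helper: reproduces `conv` with `num_chunks` smaller convs,
    each computing a disjoint slice of the output channels."""

    def __init__(self, conv: nn.Conv2d, num_chunks: int):
        super().__init__()
        assert conv.groups == 1, "sketch only handles groups == 1"
        assert conv.out_channels % num_chunks == 0
        chunk = conv.out_channels // num_chunks
        has_bias = conv.bias is not None
        self.parts = nn.ModuleList()
        for i in range(num_chunks):
            part = nn.Conv2d(conv.in_channels, chunk, conv.kernel_size,
                             stride=conv.stride, padding=conv.padding,
                             dilation=conv.dilation, bias=has_bias,
                             padding_mode=conv.padding_mode)
            # Copy the i-th slice of the original weights (and bias) into this part.
            with torch.no_grad():
                part.weight.copy_(conv.weight[i * chunk:(i + 1) * chunk])
                if has_bias:
                    part.bias.copy_(conv.bias[i * chunk:(i + 1) * chunk])
            self.parts.append(part)

    def forward(self, x):
        # Concatenating along the channel dim restores the full output.
        return torch.cat([part(x) for part in self.parts], dim=1)

conv = nn.Conv2d(64, 128, kernel_size=1)
split = OutChannelSplitConv2d(conv, num_chunks=4)
x = torch.randn(1, 64, 8, 8)
print(torch.allclose(split(x), conv(x), atol=1e-5))
```

This shrinks each convolution's weights and output, but every part still reads the full input. Splitting over in_channels and summing the partial outputs is the complementary option when the input is the bottleneck.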

The module isn’t stored on the node itself, but the node holds the module’s qualified name (for a call_module node, that is node.target). You can use that name to look up the actual module on the traced GraphModule, and then read its attributes such as in_channels, out_channels and stride directly.
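A sketch of that lookup: walk the graph and, for every `call_module` node, use `node.target` as the key into `dict(gm.named_modules())`. The model is a cut-down, hypothetical stand-in for the one above, and `conv_params` is an illustrative helper name:

```python
import torch.nn as nn
from torch.fx import symbolic_trace

class TinyNet(nn.Module):  # hypothetical stand-in for SplittableCNN
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(32, 64, kernel_size=1)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1, stride=2)

    def forward(self, x):
        return self.conv2(self.conv1(x))

def conv_params(gm):
    """Map each Conv2d's qualified name to (in_channels, out_channels, stride)."""
    modules = dict(gm.named_modules())
    found = {}
    for node in gm.graph.nodes:
        # For call_module nodes, node.target is the qualified module name.
        if node.op == "call_module":
            mod = modules[node.target]
            if isinstance(mod, nn.Conv2d):
                found[node.target] = (mod.in_channels, mod.out_channels, mod.stride)
    return found

print(conv_params(symbolic_trace(TinyNet())))
```

Going through the graph rather than straight through `named_modules()` only reports modules that are actually called in `forward`, which matters once the graph has been rewritten.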

See optimization.py on the master branch of pytorch/pytorch on GitHub.
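For the later step of changing a layer, the replacement module has to be set on its parent under the same attribute name. Here is a minimal sketch of that parent-name-plus-setattr idea, in the same spirit as the helpers in that file. `swap_module` and `TinyNet` are hypothetical names, not torch.fx APIs:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class TinyNet(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(32, 64, kernel_size=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=1)

    def forward(self, x):
        return self.conv2(self.conv1(x))

def swap_module(gm, qualified_name, new_module):
    """Replace the submodule at `qualified_name` (e.g. "block.conv1")."""
    parent_name, _, attr = qualified_name.rpartition(".")
    parent = gm.get_submodule(parent_name)  # "" returns gm itself
    setattr(parent, attr, new_module)

gm = symbolic_trace(TinyNet())
old = gm.get_submodule("conv2")
# Same in/out channels as before, but stride 2 instead of 1.
swap_module(gm, "conv2", nn.Conv2d(old.in_channels, old.out_channels, kernel_size=1, stride=2))
print(gm(torch.randn(1, 32, 8, 8)).shape)
```

Because the generated forward looks up `self.conv2` at call time, swapping the attribute is enough here. Changing a layer's out_channels also requires adjusting the next layer's in_channels.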

This was really cool:

```python
trace = symbolic_trace(model)
modules = dict(trace.named_modules())

modules['conv1'].in_channels  # correctly gives 32
```
