Hi folks, I'm building a proof of concept where I detect large convolutions and break them apart into smaller ones for devices with limited on-device memory.
As a first step, I just want to detect where a large convolution may occur.
To do so, I'm using `torch.fx` to extract `in_channels`, `out_channels` and the stride; if they're larger than a certain threshold, then I know I have a big convolution.
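For concreteness, here is a rough sketch of the detection step I have in mind. It assumes that a `call_module` node only carries the submodule's qualified name in `node.target`, so the conv parameters have to be read from the module itself via `named_modules()`. The helper name `find_large_convs`, the `channel_threshold` argument, and the rule "flag a conv if either channel count reaches the threshold" are my own placeholders, not anything from torch.fx.

```python
import torch.nn as nn
from torch.fx import symbolic_trace


class SplittableCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(32, 64, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(128, 32, kernel_size=1, stride=2)

    def forward(self, x):
        return self.conv3(self.conv2(self.conv1(x)))


def find_large_convs(traced, channel_threshold):
    """Hypothetical helper: list (name, in_channels, out_channels, stride) of big convs."""
    # Map qualified names ("conv1", ...) to the actual submodules.
    modules = dict(traced.named_modules())
    large = []
    for node in traced.graph.nodes:
        # Only call_module nodes refer to submodules; node.target is the name.
        if node.op != "call_module":
            continue
        mod = modules[node.target]
        if not isinstance(mod, nn.Conv2d):
            continue
        # Assumed criterion: either channel count reaches the threshold.
        if max(mod.in_channels, mod.out_channels) >= channel_threshold:
            large.append((node.target, mod.in_channels, mod.out_channels, mod.stride))
    return large


traced = symbolic_trace(SplittableCNN())
large_convs = find_large_convs(traced, channel_threshold=128)
print(large_convs)
```

With a threshold of 128 this should flag `conv2` and `conv3` but not `conv1`. The stride is only reported here; how it should feed into the "large" decision is still open.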
Unfortunately, I can't seem to find where the parameters of a convolution operation are stored on the node (https://github.com/…/pytorch/blob/master/torch/fx/node.py).
I know the trace is aware of the parameters, because they show up if I just run `symbolic_trace(model)`.
Any suggestions? Should I be working with the `code` object instead of the `graph` object?
More specifically:
import torch.nn as nn
from torch.fx import symbolic_trace

class SplittableCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1, stride=2)

    def forward(self, x):
        # Chain the three convolutions: conv1 -> conv2 -> conv3
        output = self.conv3(self.conv2(self.conv1(x)))
        return output

model = SplittableCNN()
symbolic_trace(model)
# outputs
SplittableCNN(
  (conv1): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(128, 32, kernel_size=(1, 1), stride=(2, 2))
)
I want to extract the `in_channels` and `out_channels`, and at a later stage change them.
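To make the "change them later" part concrete, this is a minimal sketch of what I imagine it could look like, assuming that swapping the submodule on the traced module is enough because the graph only refers to it by name. I only touch the last layer so no downstream layer's `in_channels` has to follow, and the new value of 16 output channels is arbitrary.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class SplittableCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(32, 64, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=1, stride=1)
        self.conv3 = nn.Conv2d(128, 32, kernel_size=1, stride=2)

    def forward(self, x):
        return self.conv3(self.conv2(self.conv1(x)))


traced = symbolic_trace(SplittableCNN())

# Read the current parameters from the submodule behind the "conv3" node.
old_conv = dict(traced.named_modules())["conv3"]
print(old_conv.in_channels, old_conv.out_channels)

# Build a replacement with a different out_channels (16 is an arbitrary choice).
# Note: the new layer has freshly initialized weights, not the old ones.
new_conv = nn.Conv2d(
    old_conv.in_channels,
    16,
    kernel_size=old_conv.kernel_size,
    stride=old_conv.stride,
)

# The graph calls the submodule by name, so rebinding the attribute swaps it in.
setattr(traced, "conv3", new_conv)

out = traced(torch.randn(1, 32, 8, 8))
print(out.shape)
```

For a nested target such as `block.conv`, a plain `setattr` on the top-level module would not work as-is; the parent module would have to be looked up first.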