Detecting functional calls in forward()

Let’s say I have the following model:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.lin1 = nn.Linear(784, 128)
        self.lin2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x

I want to inspect the network and get the sequence of operations it performs. Using .children() gives the following:

network = Net()

print(list(network.children()))

Out: [Linear(in_features=784, out_features=128, bias=True), Linear(in_features=128, out_features=10, bias=True)]

How can I detect the torch.nn.functional calls in a Module's forward pass?


For your purpose, the best option is to define the model as an nn.Sequential and print it.

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.seq_model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        x = self.seq_model(x)
        return x
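A quick check of the Sequential version: because ReLU is now an nn.Module, it is registered as a child and shows up when the container is printed or iterated. This is a minimal sketch of that behaviour; note it only helps for models you define yourself and does not reveal F.* calls inside an existing forward().

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # ReLU is a module here, so it is registered as a submodule.
        self.seq_model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.seq_model(x)


network = Net()
print(network.seq_model)  # lists Linear, ReLU, Linear in order

# Iterating the Sequential yields the layers in execution order.
layer_names = [type(m).__name__ for m in network.seq_model]
print(layer_names)

# Sanity check: a batch of 2 flattened 28x28 inputs maps to 10 outputs each.
out = network(torch.randn(2, 784))
```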

I’m hoping to detect calls to torch.nn.functional. I know I can use a Sequential module, but I’m hoping to detect functional operations in, say, pre-trained networks.
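One way to do this without rewriting the model is torch.fx symbolic tracing, which records forward() as a graph where functional calls appear as "call_function" nodes and submodule calls as "call_module" nodes. A minimal sketch, assuming the model's forward() is symbolically traceable (no data-dependent Python control flow), which is not guaranteed for every pre-trained network:

```python
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(784, 128)
        self.lin2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x


# Trace forward() into a graph; no real input tensor is needed.
traced = torch.fx.symbolic_trace(Net())

# Walk the nodes in execution order, skipping the input/output placeholders.
ops = []
for node in traced.graph.nodes:
    if node.op in ("call_module", "call_function", "call_method"):
        ops.append((node.op, node.target))
        print(node.op, getattr(node.target, "__name__", node.target))

# Keep only calls to functions defined in torch.nn.functional.
functional_calls = [
    target
    for op, target in ops
    if op == "call_function" and getattr(target, "__module__", "") == "torch.nn.functional"
]
```

print(traced.code) is another quick way to see the recovered sequence of operations as Python source.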


Did you find a solution to this? A roundabout way would be to run a dummy input through the model and trace all the function calls, but that doesn't solve my problem, since I intend to replace the modules in place.
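For the replacement use case, torch.fx graphs can be edited as well as inspected, and symbolic tracing does not need a dummy input. A minimal sketch, assuming forward() is symbolically traceable, that swaps each F.relu call for a new nn.ReLU submodule. The helper name functional_relu_to_module and the relu_N submodule names are hypothetical. It returns a new GraphModule that shares the original parameters, rather than mutating the original Net object.

```python
import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(784, 128)
        self.lin2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.lin2(F.relu(self.lin1(x)))


def functional_relu_to_module(model: nn.Module) -> torch.fx.GraphModule:
    """Hypothetical helper: replace F.relu calls with nn.ReLU submodules."""
    traced = torch.fx.symbolic_trace(model)
    count = 0
    for node in list(traced.graph.nodes):
        if node.op == "call_function" and node.target is F.relu:
            name = f"relu_{count}"  # hypothetical naming scheme
            count += 1
            # Register the new module so a call_module node can refer to it.
            traced.add_module(name, nn.ReLU(inplace=node.kwargs.get("inplace", False)))
            with traced.graph.inserting_after(node):
                new_node = traced.graph.call_module(name, args=(node.args[0],))
            # Point all consumers at the module call, then drop the old node.
            node.replace_all_uses_with(new_node)
            traced.graph.erase_node(node)
    traced.graph.lint()
    traced.recompile()  # regenerate forward() from the edited graph
    return traced


model = Net()
new_model = functional_relu_to_module(model)
print(new_model.code)

x = torch.randn(4, 784)
same_output = torch.allclose(model(x), new_model(x))
```

Because the rewritten ReLU is now a real submodule, it shows up in new_model.children() and can itself be swapped with setattr like any other layer.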