Let’s say I have the following model:
```python
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.lin1 = nn.Linear(784, 128)
        self.lin2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x
```
I want to inspect the network and get the sequence of operations it performs. Using `.children()` gives the following:
```python
network = Net()
print(list(network.children()))
# Out: [Linear(in_features=784, out_features=128, bias=True), Linear(in_features=128, out_features=10, bias=True)]
```
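`.children()` only returns registered submodules, so the `F.relu` call inside `forward` does not show up. How can I detect the `torch.nn.functional` calls in a Module's forward pass?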
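One possible approach, sketched here under the assumption that the installed PyTorch ships `torch.fx` (it was added around version 1.8), is to symbolically trace the model and walk the resulting graph. Submodule calls appear as `call_module` nodes and free functions such as `F.relu` appear as `call_function` nodes, in the order `forward` executes them. The helper `list_operations` is hypothetical, not part of PyTorch.

```python
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(784, 128)
        self.lin2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x


def list_operations(module: nn.Module):
    """Hypothetical helper: return (kind, target) pairs in forward() order."""
    traced = symbolic_trace(module)  # builds a GraphModule from forward()
    ops = []
    for node in traced.graph.nodes:
        if node.op == "call_module":
            # Submodule call: target is the attribute path, e.g. "lin1"
            ops.append(("call_module", node.target))
        elif node.op == "call_function":
            # Free function call such as F.relu: target is the function object
            ops.append(("call_function", node.target))
        # "placeholder" (inputs) and "output" nodes are skipped
    return ops


network = Net()
for kind, target in list_operations(network):
    # Functions print by name; submodule targets are already strings
    print(kind, getattr(target, "__name__", target))
```

Because symbolic tracing records a single path through `forward` with proxy inputs, models whose control flow depends on input values may fail to trace; in that case a different inspection method would be needed.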