Hi,
As you might know, PyTorch builds its computation graph dynamically as data flows through the network, so you can capture the network's behavior by passing in an example input and tracing the graph construction with `torch.jit.trace`. Here is the code:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Three-layer fully connected classifier over flattened inputs.

    Architecture: ``in_features -> hidden1 -> hidden2 -> num_classes``.
    ReLU is applied after the first two linear layers and ``log_softmax``
    after the last, so ``forward`` returns per-class log-probabilities.

    All sizes are keyword parameters; the defaults reproduce the original
    784 -> 300 -> 100 -> 10 network (784 = 28 * 28).
    """

    def __init__(self, in_features=28 * 28, hidden1=300, hidden2=100, num_classes=10):
        super().__init__()
        # Stored so forward() flattens to exactly the width fc1 expects,
        # instead of repeating the 28 * 28 literal in two places.
        self.in_features = in_features
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape ``(N, num_classes)``.

        ``x`` is flattened to ``(N, in_features)`` first, so any tensor whose
        element count is a multiple of ``in_features`` is accepted
        (e.g. ``(N, 784)`` or ``(N, 1, 28, 28)`` with the defaults).
        """
        # view() (rather than reshape()) needs strides compatible with the new
        # shape, e.g. a contiguous tensor; it is kept so the traced code
        # still reads ``torch.view(x, [-1, 784])``.
        x = x.view(-1, self.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)
net = Net()
# A single example is enough: tracing runs the model once and records the
# tensor operations it executes. Each row is one flattened 28x28 input.
x = torch.randn(1, 28 * 28)
traced_net = torch.jit.trace(net, example_inputs=x)

# Module hierarchy of the traced model first, then the TorchScript source
# recorded by the trace, each on its own line(s).
print(traced_net, traced_net.code, sep="\n")
### prints
#Net(
# original_name=Net
# (fc1): Linear(original_name=Linear)
# (fc2): Linear(original_name=Linear)
# (fc3): Linear(original_name=Linear)
#)
#def forward(self,
# x: Tensor) -> Tensor:
# _0 = self.fc3
# _1 = self.fc2
# _2 = self.fc1
# input = torch.view(x, [-1, 784])
# input0 = torch.relu((_2).forward(input, ))
# input1 = torch.relu((_1).forward(input0, ))
# _3 = torch.log_softmax((_0).forward(input1, ), 1, None)
# return _3
For more details, please read the official PyTorch tutorial "Introduction to TorchScript".
Best,