Understanding the structure of a PyTorch graph

Hi,

For my use case, I need to take a PyTorch module and interpret the sequence of layers in it so that I can write out the “connections” between the layers in some file format. Say I have a simple module like the one below:

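import torch
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x  # residual add: a plain tensor op, not an nn.Module
        return out


if __name__ == "__main__":
    net = mymodel(5)

    # modules() yields the model itself, then every registered submodule
    for mod in net.modules():
        print(mod)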

This prints:

mymodel(
  (fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)

As you can see, the in-place addition (out += x) in forward is not captured, because it is not an nn.Module. My goal is to build a graph of connections from the PyTorch module object and store it as something like this JSON:

{
  "layers": {
    "fc": {
      "inputTensor": "t0",
      "outputTensor": "t1"
    },
    "addOp": {
      "inputTensor": ["t1", "t0"],
      "outputTensor": "t2"
    }
  }
}

The tensor names are arbitrary, but this captures the essence of the graph and the connections between layers. Is there a way to extract this information from a PyTorch object? I was thinking of using .modules(), but then realized that hand-written operations are not captured this way because they are not modules. I suppose that if everything were an nn.Module, .modules() might give me the arrangement of the network layers. Any help is appreciated. Thanks!
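One programmatic option, offered here as a sketch and not something confirmed elsewhere in this thread, is torch.fx (PyTorch 1.8+). torch.fx.symbolic_trace records a forward pass as a graph whose nodes cover both submodule calls (call_module, e.g. self.fc) and plain tensor ops (call_function/call_method, e.g. the addition), and each node lists the nodes it reads via all_input_nodes. The helper graph_to_layers, the "t0", "t1", ... naming, and the extra "target" key are hypothetical choices made to mimic the JSON above. Because the addition reads two tensors, inputTensor is emitted as a list.

```python
import json

import torch
import torch.fx
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


def graph_to_layers(module):
    """Hypothetical helper: trace `module` with torch.fx and describe each op
    by the names of the tensors it reads and writes."""
    traced = torch.fx.symbolic_trace(module)
    tensor_names = {}  # fx.Node -> arbitrary tensor name ("t0", "t1", ...)
    layers = {}
    for node in traced.graph.nodes:
        if node.op in ("placeholder", "get_attr"):
            # Inputs to forward() (here: x) and directly accessed attributes
            tensor_names[node] = f"t{len(tensor_names)}"
        elif node.op in ("call_module", "call_function", "call_method"):
            # Submodule calls (self.fc) and plain tensor ops (the add) alike
            inputs = [tensor_names[n] for n in node.all_input_nodes]
            tensor_names[node] = f"t{len(tensor_names)}"
            layers[node.name] = {
                "target": str(node.target),
                "inputTensor": inputs,
                "outputTensor": tensor_names[node],
            }
        # "output" nodes only return a value, so they are not recorded
    return layers


if __name__ == "__main__":
    print(json.dumps({"layers": graph_to_layers(mymodel(5))}, indent=2))
```

For this model the sketch should yield an "fc" entry reading t0 and writing t1, plus one entry for the addition reading t1 and t0 and writing t2; the addition's key is whatever name fx generates for that node, not "addOp". Note that symbolic tracing cannot follow control flow that depends on tensor values, so some models may need a custom tracer or a different approach.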

Check out torchviz (pip install torchviz). It uses Graphviz to render the autograd graph of a model's output.

Thanks! I was looking for something programmatic rather than visual, so that I can store the information in the format mentioned above. Any suggestions?
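The graph torchviz draws is the autograd graph, reached from an output tensor's grad_fn and followed through each node's next_functions, and that same walk can be done in plain code. The sketch below assumes a hypothetical helper, autograd_edges, that returns (producer, consumer) pairs. Be aware that the nodes are backward ops (for example AddBackward0 or AccumulateGrad) rather than layer names, exact names vary between PyTorch versions, a forward pass with a sample input is required, and tensors that do not require grad do not appear at all.

```python
import torch
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


def autograd_edges(output):
    """Hypothetical helper: walk the autograd graph backwards from `output`
    and return (producer, consumer) edges between backward nodes."""
    if output.grad_fn is None:
        return []  # no graph was recorded (e.g. under torch.no_grad())
    labels = {}  # backward node -> unique label such as "AddBackward0_0"

    def label(fn):
        if fn not in labels:
            labels[fn] = f"{type(fn).__name__}_{len(labels)}"
        return labels[fn]

    edges = []
    seen = set()
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn in seen:
            continue
        seen.add(fn)
        consumer = label(fn)
        for next_fn, _ in fn.next_functions:
            if next_fn is None:  # an input that does not require grad
                continue
            edges.append((label(next_fn), consumer))
            stack.append(next_fn)
    return edges


if __name__ == "__main__":
    net = mymodel(5)
    # x must require grad, otherwise it (and the skip connection) is not in the graph
    x = torch.randn(2, 5, requires_grad=True)
    for producer, consumer in autograd_edges(net(x)):
        print(producer, "->", consumer)
```

Mapping these backward nodes back to names like "fc" takes extra bookkeeping, which is one reason a forward-graph tool such as torch.fx may be a better fit for the JSON format above.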

Try model.named_parameters().
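For reference, a minimal sketch of what named_parameters() returns for the model in the question: (dotted name, tensor) pairs for the learnable tensors only, keyed by the submodule path. That can help attach weights to layer names such as "fc", but it records neither execution order nor how layers are connected, and the addition does not appear because it has no parameters.

```python
import torch.nn as nn


class mymodel(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.fc = nn.Linear(input_channels, input_channels)

    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


net = mymodel(5)
# named_parameters() yields (dotted name, tensor) pairs for learnable tensors only
params = {name: tuple(p.shape) for name, p in net.named_parameters()}
print(params)  # {'fc.weight': (5, 5), 'fc.bias': (5,)}
```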