How to print/iterate through the network structure in PyTorch?

PyTorch doesn’t create a static graph representation of a neural network, so we cannot iterate through the network’s graph structure (and query or print it) once the model is created. Is it possible to get information about the network structure at runtime? Alternatively, is there a static mode that we can run PyTorch in?

Here’s my use case: I am taking a neural network from the PyTorch Hub and trying to perform some analysis on it. So, I need to iterate over the layers in the network and query things like tensor sizes and next/previous layer names.
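For the tensor-size part specifically, one runtime option is to attach forward hooks and record each module’s output shape during a forward pass. A minimal sketch (the `TinyNet` model here is a hypothetical stand-in for a Hub model):

```python
# Record per-module output shapes at runtime with forward hooks.
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(x.flatten(1))

shapes = {}

def make_hook(name):
    def hook(module, inputs, output):
        shapes[name] = tuple(output.shape)
    return hook

model = TinyNet()
handles = [m.register_forward_hook(make_hook(n))
           for n, m in model.named_modules() if n]  # skip the root module

model(torch.randn(1, 3, 32, 32))
for h in handles:
    h.remove()

print(shapes["conv"])  # (1, 8, 32, 32)
print(shapes["fc"])    # (1, 10)
```

Note this only captures shapes for registered submodules, not for functional calls or operators like `+`, and it requires actually running an example input through the model.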

There are multiple ways to list or iterate over the flattened list of layers in the network (including the Keras-style model.summary from sksq96’s pytorch-summary GitHub repo). But the problem with these methods is that they don’t provide information about the edges of the neural network graph (e.g. which layer comes before a particular layer, or which layer this layer feeds into). These lists also seem to miss the elementwise addition layers created using the “+” operator in the network code.
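For what it’s worth, `torch.fx` can capture exactly this kind of edge information: `symbolic_trace` produces a graph whose nodes expose their predecessors (`all_input_nodes`) and successors (`users`), and a `+` in the forward code appears as a `call_function` node targeting `operator.add`. A small sketch with a hypothetical residual-style module:

```python
# Inspect graph edges (including "+" ops) with torch.fx.
import operator
import torch
import torch.fx
import torch.nn as nn

class Residual(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc2(x) + self.fc1(x)  # elementwise add via "+"

traced = torch.fx.symbolic_trace(Residual())
for node in traced.graph.nodes:
    preds = [n.name for n in node.all_input_nodes]
    succs = [n.name for n in node.users]
    print(f"{node.op:14} {node.name:8} preds={preds} succs={succs}")

# The "+" shows up as a call_function node targeting operator.add:
add_nodes = [n for n in traced.graph.nodes
             if n.op == "call_function" and n.target is operator.add]
print(len(add_nodes))  # 1
```

Symbolic tracing has its own limits (e.g. data-dependent control flow), but for many Hub models it gives the predecessor/successor structure that a flattened module list cannot.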

I searched around and do see some posts that talk about it, but haven’t found a solution:

python - How do I get the precursor nodes of each layer in Pytorch? - Stack Overflow

How to traverse a network - PyTorch Forums

Note that the networks are not defined using just nn.Sequential. They can be complex, e.g. containing multiple nn.Sequential blocks and other modules.

Any help/advice is appreciated.

A scripted or traced model provides a representation of the model, which you could probably use for your experiments.
Have you checked the code generated from a scripted model, and would it be useful in your use case?
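To illustrate the suggestion: after `torch.jit.script` (or `torch.jit.trace`), the resulting module exposes a Python-like `code` string and a low-level `graph` IR with explicit data-flow edges. A minimal sketch, assuming a toy `nn.Sequential` in place of a real Hub model:

```python
# Inspect the code and graph of a scripted / traced model.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
scripted = torch.jit.script(model)

print(scripted.code)   # Python-like source of forward()
print(scripted.graph)  # IR where each value's producers/consumers are explicit

# torch.jit.trace instead records the ops actually executed on an example input
traced = torch.jit.trace(model, torch.randn(1, 4))
print(traced.graph)
```

The scripted module still behaves like the original model, so it can be called normally while also being inspectable.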


Thanks for your response, Piotr :slight_smile:
I haven’t heard of scripted or traced models. Let me check them out.
