PyTorch doesn’t build a static graph representation of a neural network, so once a model is created we cannot iterate through its layers as a graph (and query or print them). Is it possible to get information about the network structure at runtime? Alternatively, is there a static mode that PyTorch can be run in?
Here’s my use case: I am taking a neural network from the PyTorch Hub and trying to perform some analysis on it. To do that, I need to iterate over the layers in the network and query for things like tensor sizes and next/previous layer names.
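To show what I can already get: forward hooks give me per-layer output sizes, but still no connectivity. A minimal sketch (a toy `nn.Sequential` standing in for a hub model, since the real one is large):

```python
import torch
import torch.nn as nn

# Toy stand-in for a hub model.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

sizes = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Record the output tensor size of each module as it runs.
        sizes[name] = tuple(output.shape)
    return hook

for name, module in model.named_modules():
    if name:  # skip the root container itself
        module.register_forward_hook(make_hook(name))

model(torch.randn(3, 4))
print(sizes)  # {'0': (3, 8), '1': (3, 8), '2': (3, 2)}
```

So the sizes are recoverable at runtime, but the hooks tell me nothing about which module feeds which.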
There are multiple ways to list out or iterate over the flattened list of layers in the network (including a Keras-style model.summary from sksq96’s pytorch-summary on GitHub). But the problem with these methods is that they don’t provide information about the edges of the neural network graph (e.g. which layer comes before a particular layer, or which layer a given layer feeds into). These listings also seem to miss the elementwise addition layers created using the “+” operator in the network code.
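To make that concrete, here is a hypothetical toy module (not from the hub, just a minimal stand-in) with a skip connection written as “+”. Iterating with `named_modules()` shows the submodules but no edges, and the addition never appears at all:

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        # The "+" below is a tensor op, not an nn.Module,
        # so no iteration over modules will ever list it.
        return self.relu(self.conv(x) + x)

model = Block()
for name, module in model.named_modules():
    print(name or "<root>", "->", type(module).__name__)
# Lists conv and relu, but nothing for the elementwise add,
# and nothing about which module feeds which.
```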
I searched around and do see some posts that talk about this, but I haven’t found a solution.
Note that the networks are not defined with just a single nn.Sequential. They can be complex, containing multiple nn.Sequential blocks alongside other modules.
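For illustration, a hypothetical example of the kind of structure I mean (two nn.Sequential blocks, a “+” skip connection, and a non-module reduction in forward()):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.branch = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU())
        self.head = nn.Linear(8, 10)

    def forward(self, x):
        x = self.stem(x)
        x = self.branch(x) + x      # skip connection: an edge invisible to module iteration
        x = x.mean(dim=(2, 3))      # global average pool, also not a module here
        return self.head(x)
```

Flattening this model’s modules gives stem.0, stem.1, branch.0, branch.1, head, but nothing tells me that stem feeds both branch and the addition.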
Any help/advice is appreciated.