Hi, I already asked a question about this, but after coming back to it I think the answer is missing some points.
So I would like to traverse the graph of modules, starting from the input, in a depth-first or breadth-first search manner.
My goal is to automatically search & replace blocks (subgraphs) in a model (graph).
Here is my previously answered question:
So after looking at the docs, it seems that `torch.nn.Module.named_modules()` iterates over all the `torch.nn.Module` submodules registered in the constructor, I imagine in the order in which they are declared. The problem is that this order does not necessarily follow a depth-first or breadth-first traversal of the computation graph, especially if the network doesn't have a sequential structure.
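To make the ordering issue concrete, here is a minimal sketch (the toy model `Branchy` is hypothetical): `named_modules()` lists submodules in the order they were registered in `__init__`, even though `forward` runs them in a different order.

```python
import torch
import torch.nn as nn


class Branchy(nn.Module):
    """Hypothetical toy model whose declaration order differs from its execution order."""

    def __init__(self):
        super().__init__()
        # Declared (registered) in the order: head, left, right, stem
        self.head = nn.Linear(8, 2)
        self.left = nn.Linear(4, 4)
        self.right = nn.Linear(4, 4)
        self.stem = nn.Linear(4, 4)

    def forward(self, x):
        # Executed in the order: stem -> (left, right) -> head
        h = self.stem(x)
        return self.head(torch.cat([self.left(h), self.right(h)], dim=1))


model = Branchy()
# named_modules() yields the root module as "" first, then children in registration order
declared = [name for name, _ in model.named_modules() if name]
print(declared)  # ['head', 'left', 'right', 'stem']
```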
The computation graph is defined by the order of computation in the forward method, not the order of declaration in the constructor…
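One possible approach, not taken from the original thread, is to let `torch.fx.symbolic_trace` record what `forward` actually executes as a `torch.fx.Graph`; each node exposes the nodes that consume its output through `node.users`, which is exactly the edge set a BFS/DFS needs. This is a sketch under the assumption that the model is symbolically traceable (no control flow that depends on tensor values); `Branchy`, `bfs` and `dfs` are hypothetical names for illustration.

```python
from collections import deque

import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F


class Branchy(nn.Module):
    # Hypothetical toy model: declared head-first, executed stem-first
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(8, 2)
        self.left = nn.Linear(4, 4)
        self.right = nn.Linear(4, 4)
        self.stem = nn.Linear(4, 4)

    def forward(self, x):
        h = F.relu(self.stem(x))
        return self.head(torch.cat([self.left(h), self.right(h)], dim=1))


def bfs(graph: fx.Graph):
    """Yield nodes breadth-first, starting from the input placeholders."""
    queue = deque(n for n in graph.nodes if n.op == "placeholder")
    seen = set(queue)
    while queue:
        node = queue.popleft()
        yield node
        for user in node.users:  # nodes that consume this node's output
            if user not in seen:
                seen.add(user)
                queue.append(user)


def dfs(graph: fx.Graph):
    """Yield nodes depth-first (pre-order), starting from the input placeholders."""
    stack = [n for n in graph.nodes if n.op == "placeholder"][::-1]
    seen = set()
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        yield node
        # Push users in reverse so the first consumer is explored first
        stack.extend(reversed(list(node.users)))


gm = fx.symbolic_trace(Branchy())
for node in bfs(gm.graph):
    print(node.op, node.name)
```

Because the edges come from data flow rather than attribute registration, `stem` is visited right after the input even though it is declared last. Also, `gm.graph.nodes` is already stored in topological (execution) order, so iterating over it directly may be enough when a strict BFS/DFS is not required.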
Another ("more minor") problem is that when some computation uses the functional API (for example, calling `torch.nn.functional.relu()` in the `forward` method), `named_modules()` can't see it.
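A small, hypothetical illustration of this blind spot: the ReLU below is a plain function call, so there is no `nn.ReLU` submodule for `named_modules()` to report. Tracing with `torch.fx` (a technique not mentioned in the question, and only applicable when the model is symbolically traceable) does record it, as a `call_function` node.

```python
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F


class Tiny(nn.Module):
    # Hypothetical model mixing a submodule with a functional call
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return F.relu(self.fc(x))  # functional call: no nn.ReLU submodule exists


model = Tiny()
print([name for name, _ in model.named_modules()])  # ['', 'fc']

gm = fx.symbolic_trace(model)
for node in gm.graph.nodes:
    # The functional relu shows up here with op == "call_function"
    print(node.op, node.name)
```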
So I am still looking for a way to do a depth-first or breadth-first graph traversal in order to search for and replace blocks.
Is there a way I can achieve this?
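For the search-and-replace part, one option worth trying is `torch.fx.replace_pattern`, which finds every occurrence of a subgraph described by a pattern function in a traced `GraphModule` and swaps in a replacement function. This sketch assumes a traceable model; `Net`, `pattern` and `replacement` are hypothetical names, and swapping ReLU for GELU is just an arbitrary example of a block replacement. The pattern is kept purely functional for simplicity.

```python
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Hypothetical model containing two occurrences of the block to replace
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 4)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.relu(self.fc2(x))


def pattern(x):
    # The "block" to search for, written as a function of its inputs
    return F.relu(x)


def replacement(x):
    # What each match is replaced with (arbitrary illustrative choice)
    return F.gelu(x)


gm = fx.symbolic_trace(Net())
matches = fx.replace_pattern(gm, pattern, replacement)  # rewrites gm in place
print(len(matches))
print(gm.code)  # regenerated forward, now calling gelu
```

`replace_pattern` returns the list of matches it found and modifies `gm` in place, so `gm` can be used directly as the rewritten model.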