Graph traversal of torch.nn.Module

Hi, I already asked a question about this, but coming back to it, I think the answer missed some points.
I would like to traverse the graph of modules, starting from the input, in a depth-first or breadth-first manner.
My goal is to automatically search & replace blocks (subgraphs) in a model (graph).

Here is my previous, already answered question:

After looking at the docs, it seems that torch.nn.Module.named_modules() iterates over all the torch.nn.Module instances registered as attributes (typically in the constructor), presumably in the order in which they are declared. The problem is that this order does not necessarily match a depth-first or breadth-first traversal of the computation graph, especially when the network is not purely sequential.
The computation graph is defined by the order of computation in the forward method, not by the order of declaration in the constructor.
Another (more minor) problem is that operations done through the functional API (like calling torch.nn.functional.relu() in the forward method) are not modules, so named_modules() can't see them.
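Both points can be made concrete with a hypothetical toy model `Net` (not from the original post): its submodules are declared in one order and used in the other, and the ReLU is applied with `torch.nn.functional.relu`, so it never appears in `named_modules()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Hypothetical toy model: declaration order != execution order."""

    def __init__(self):
        super().__init__()
        # Registered first, but used last in forward()
        self.fc2 = nn.Linear(8, 4)
        # Registered second, but used first in forward()
        self.fc1 = nn.Linear(4, 8)

    def forward(self, x):
        # fc1 -> functional relu -> fc2; F.relu is not a module
        return self.fc2(F.relu(self.fc1(x)))


net = Net()
names = [name for name, _ in net.named_modules()]
print(names)  # ['', 'fc2', 'fc1']: registration order, and no relu entry
```

The empty name `''` is the root module itself; after it come the children in registration order, which here is the reverse of the data flow.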

So I am still looking for a way to do a depth-first or breadth-first graph traversal in order to search and replace blocks.
Is there a way I can achieve this?

Hi,

One thing you most likely want to check out is torch.fx (see the torch.fx page of the PyTorch master documentation).
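As a rough sketch of what torch.fx gives you (assuming a PyTorch version that ships `torch.fx`, and reusing a hypothetical toy model), `symbolic_trace` records the forward pass as a graph of nodes. Functional calls show up as `call_function` nodes, and each node's `users` lets you walk the data flow from the inputs. The `bfs_from_inputs` helper is a hypothetical name written for illustration, not part of the torch.fx API.

```python
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace


class Net(nn.Module):
    """Hypothetical toy model: declaration order != execution order."""

    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(8, 4)
        self.fc1 = nn.Linear(4, 8)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


gm = symbolic_trace(Net())  # GraphModule holding a torch.fx.Graph

# graph.nodes is stored in topological (execution) order, and
# functional calls such as F.relu appear as 'call_function' nodes.
for node in gm.graph.nodes:
    print(node.op, node.name)


def bfs_from_inputs(graph):
    """Hypothetical helper: breadth-first walk along data flow from the inputs."""
    inputs = [n for n in graph.nodes if n.op == "placeholder"]
    seen = set(inputs)
    queue = deque(inputs)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        # node.users: the nodes that consume this node's output
        for user in node.users:
            if user not in seen:
                seen.add(user)
                queue.append(user)
    return order


print([n.name for n in bfs_from_inputs(gm.graph)])
```

Swapping `queue.popleft()` for `queue.pop()` turns the same loop into a depth-first traversal.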

It looks promising, and will probably help me achieve what I want to do right now.
Any idea on when it could be released?
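For the search & replace goal, torch.fx also exposes a subgraph rewriter, `torch.fx.replace_pattern(gm, pattern, replacement)`, which traces two plain functions and swaps every match of the first for the second. The sketch below assumes that API is available in your PyTorch version; the toy model and the choice of replacing the ReLU with a sigmoid are purely illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import replace_pattern, symbolic_trace


class Net(nn.Module):
    """Hypothetical toy model using a functional ReLU."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


# The subgraph to search for...
def pattern(x):
    return F.relu(x)


# ...and what to put in its place (an arbitrary choice for illustration).
def replacement(x):
    return torch.sigmoid(x)


gm = symbolic_trace(Net())
# Rewrites gm's graph in place, recompiles it, and returns the matches found
matches = replace_pattern(gm, pattern, replacement)

x = torch.randn(2, 4)
expected = gm.fc2(torch.sigmoid(gm.fc1(x)))
print(len(matches), torch.allclose(gm(x), expected))
```

Because the pattern is itself traced, it can contain several operations, which is what makes this usable for replacing whole blocks rather than single layers.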

It will be a beta feature in the upcoming 1.8 release.

Right, I imagine you don't want to venture into release-date forecasts :sweat_smile:
Can’t wait for it! :grinning:

Cheers

The release branch is planned to be cut later today.
So between testing, release notes and all, you should have it by the end of the month :slight_smile:

Wow, good work guys, that's awesome!