List layers/nodes executed in order of forward pass

I am trying to implement a CNN visualization technique. Essentially, I need a list of the operations executed during the forward pass, in execution order, so that the algorithm can work through them backwards.

Due to the nature of the algorithm, operations such as pooling and reshaping must also be registered.

I could have the user build the list of operations manually, but that isn't ideal, since I also use forward hooks to grab the data passed between modules.

I found a topic suggesting that the JIT tracer could be used for this, but I couldn't see how to do it from the documentation.
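One possible way to read an op list out of the JIT tracer is sketched below, assuming a reasonably recent PyTorch. It traces the model with `torch.jit.trace`, takes the traced module's `inlined_graph` (which inlines submodule calls, so ops inside `self.conv` and `self.fc` show up directly) and keeps the `aten::` node kinds in graph order. `SmallNet` and `traced_ops` are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmallNet(nn.Module):
    """Hypothetical example model mixing modules and functional calls."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, 3)
        self.fc = nn.Linear(4 * 3 * 3, 2)

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = F.max_pool2d(x, 2)   # functional op, no nn.Module involved
        x = torch.flatten(x, 1)  # reshape
        return self.fc(x)


def traced_ops(model, example_input):
    """Return the aten op kinds recorded by torch.jit.trace, in graph order."""
    traced = torch.jit.trace(model, example_input)
    # inlined_graph expands submodule calls, so their ops appear as aten:: nodes
    # instead of opaque prim::CallMethod nodes; prim:: bookkeeping nodes are dropped.
    return [node.kind() for node in traced.inlined_graph.nodes()
            if node.kind().startswith("aten::")]


ops = traced_ops(SmallNet().eval(), torch.randn(1, 1, 8, 8))
print(ops)
```

The result is a flat list of op names only; mapping each node back to the module (and the tensors) it belongs to would take further work on the graph.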

Since I am visualizing information about the network, I don't mind doing a trace on every forward pass; I would need to run the visualization algorithm each time anyway.

I understand there might not be a pretty way to do it, but I would appreciate any idea that could potentially work!


Update for anyone wanting to do something similar.

I found a workaround of sorts. I registered a forward hook on every "leaf" module in the model, that is, every module that has no submodules. The hook function creates a storage object holding the module, its input and its output, and appends this object to a global list.

The list is updated in the order of the forward pass. Once a forward pass is complete, I have a list of storage objects, ordered as in the forward pass, with all the information I need.
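A minimal sketch of the workaround described above, assuming plain `nn.Module` models. The `Record` dataclass, the global `RECORDS` list and the helper names are hypothetical stand-ins for the storage object and list; "leaf" is tested as "has no child modules". The hook detaches tensors so the stored copies do not keep the autograd graph alive, which is a choice made here and should be dropped if the visualization needs gradients through the stored tensors.

```python
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Any, List


@dataclass
class Record:
    """Hypothetical storage object for one leaf-module call."""
    module: nn.Module
    inputs: tuple
    output: Any


RECORDS: List[Record] = []  # global list, appended to in forward-pass order


def _detach(x):
    # Detach tensors so stored copies do not keep the autograd graph alive.
    return x.detach() if torch.is_tensor(x) else x


def record_hook(module, inputs, output):
    RECORDS.append(Record(module, tuple(_detach(i) for i in inputs), _detach(output)))


def hook_leaf_modules(model: nn.Module):
    """Register record_hook on every module without children; return the handles."""
    return [m.register_forward_hook(record_hook)
            for m in model.modules()
            if next(m.children(), None) is None]


model = nn.Sequential(
    nn.Conv2d(1, 4, 3), nn.ReLU(), nn.MaxPool2d(2), nn.Flatten(), nn.Linear(4 * 3 * 3, 2)
)
handles = hook_leaf_modules(model)
model(torch.randn(1, 1, 8, 8))
for h in handles:
    h.remove()  # stop recording once the pass is done

names = [type(r.module).__name__ for r in RECORDS]
print(names)  # ['Conv2d', 'ReLU', 'MaxPool2d', 'Flatten', 'Linear']

# The visualization algorithm can then walk the records backwards.
for r in reversed(RECORDS):
    print(type(r.module).__name__, tuple(r.output.shape))
```

The `nn.Sequential` container itself has children, so it gets no hook; only the five layers inside it are recorded.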

There are a few clear weaknesses, however. First, this method can only register hooks on modules that have been instantiated as module objects. If the model calls, e.g., torch.nn.functional.max_pool2d or any other function from torch.nn.functional in its forward method, those calls won't show up in model.modules() and won't get a forward hook.
Second, the global list may cause several problems: I am unsure how it will work in a parallel setup (multiple GPUs), or even just in networks with branching, such as skip connections.
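One way to at least avoid the shared global list is a small context manager that owns its own record list, tags each entry with the module's qualified name from `named_modules()`, and removes its hooks on exit. `ForwardRecorder` and `ResidualBlock` are hypothetical names. This sketch does not address multi-GPU setups, and the residual example shows its limit: the `+` of the skip connection is not a module, so it is not recorded and the branch structure is not recovered.

```python
import torch
import torch.nn as nn


class ForwardRecorder:
    """Hypothetical helper: records leaf-module calls into its own list
    (instead of a global one) and removes its hooks on exit."""

    def __init__(self, model: nn.Module):
        self.model = model
        self.records = []  # (qualified_name, module, inputs, output), in call order
        self._handles = []

    def _make_hook(self, name):
        def hook(module, inputs, output):
            self.records.append((name, module, inputs, output))
        return hook

    def __enter__(self):
        for name, m in self.model.named_modules():
            if next(m.children(), None) is None:  # leaf module
                self._handles.append(m.register_forward_hook(self._make_hook(name)))
        return self

    def __exit__(self, exc_type, exc, tb):
        for h in self._handles:
            h.remove()
        self._handles.clear()
        return False  # do not swallow exceptions


class ResidualBlock(nn.Module):
    """Hypothetical block with a skip connection."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 4, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(4, 4, 3, padding=1)

    def forward(self, x):
        return x + self.conv2(self.relu(self.conv1(x)))  # the "+" is not a module


model = ResidualBlock()
with ForwardRecorder(model) as rec:
    model(torch.randn(1, 4, 8, 8))

rec_names = [name for name, *_ in rec.records]
print(rec_names)  # ['conv1', 'relu', 'conv2']
```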


Could you get this to work? As you mention, it won't register non-module calls. The JIT tracer does give code, but it would still require further parsing to be useful (How to parse torch jit trace output?).
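A different technique that does see non-module calls is `torch.fx` symbolic tracing, available in PyTorch 1.8 and later. It produces a graph whose nodes are tagged `call_module`, `call_function` (e.g. `torch.nn.functional.max_pool2d`) or `call_method` (e.g. `Tensor.view`), so functional pooling and reshapes appear alongside the modules. The sketch below, with hypothetical `SmallNet` and `list_ops`, just lists those nodes in graph order. Note that symbolic tracing cannot follow data-dependent control flow, and it records the calls but not the runtime tensors.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace


class SmallNet(nn.Module):
    """Hypothetical example model mixing modules, functional calls and a reshape."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, 3)
        self.fc = nn.Linear(4 * 3 * 3, 2)

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = F.max_pool2d(x, 2)      # not an nn.Module, so no forward hook
        x = x.view(x.size(0), -1)   # reshape via a tensor method
        return self.fc(x)


def list_ops(model: nn.Module):
    """Return (kind, name) for each call in the FX graph, in graph order."""
    gm = symbolic_trace(model)
    ops = []
    for node in gm.graph.nodes:
        if node.op == "call_module":      # submodule call; target is its qualified name
            ops.append(("module", node.target))
        elif node.op == "call_function":  # e.g. torch.nn.functional.max_pool2d
            ops.append(("function", getattr(node.target, "__name__", str(node.target))))
        elif node.op == "call_method":    # e.g. Tensor.view; target is the method name
            ops.append(("method", node.target))
    return ops


fx_ops = list_ops(SmallNet())
for kind, name in fx_ops:
    print(kind, name)
```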