Capturing Forward calls

My other thread got no traction, so I’m looking at alternative ways of doing this.

For nn.Module-derived layers, I can use register_forward_pre_hook and register_forward_hook, and everything works great!
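For reference, this is roughly what the module-hook approach looks like. It is a minimal sketch with a toy nn.Sequential model (the model and the `calls` list are just for illustration). It registers both hooks on every leaf module and records input and output shapes:

```python
import torch
import torch.nn as nn

# Records ("pre"/"post", module name, shapes) for every leaf module call.
calls = []

def make_hooks(name):
    def pre_hook(module, inputs):
        # inputs is the tuple of positional args passed to forward()
        calls.append(("pre", name, [tuple(t.shape) for t in inputs if torch.is_tensor(t)]))

    def post_hook(module, inputs, output):
        # assumes the module returns a single tensor (true for this toy model)
        calls.append(("post", name, tuple(output.shape)))

    return pre_hook, post_hook

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

handles = []
for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # leaf modules only
        pre, post = make_hooks(name)
        handles.append(module.register_forward_pre_hook(pre))
        handles.append(module.register_forward_hook(post))

model(torch.randn(3, 8))

for h in handles:
    h.remove()  # detach the hooks once capture is done
```

The catch, as described next, is that anything done between modules (the `+`, slicing, `torch.cat` in a custom forward()) never shows up in `calls`.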

But I’m trying to capture the small non-module functions that frequently pepper these networks: slicing, *, /, +, -, and even things like torch.cat.

The problem is, those functions have no hooks to hook into.
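There is no register_*_hook for these, but one mechanism that might serve as a hook is PyTorch's `__torch_function__` protocol via `torch.overrides.TorchFunctionMode`. This is a swapped-in technique, not something from the thread, and the sketch assumes a recent PyTorch (2.x). While the mode is active, tensor-level calls such as `torch.cat`, arithmetic operators and `Tensor.__getitem__` are routed through `__torch_function__`. `OpRecorder` and `forward` are hypothetical names:

```python
import torch
from torch.overrides import TorchFunctionMode

class OpRecorder(TorchFunctionMode):
    """Hypothetical recorder: logs the name of every torch function/method call while active."""

    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # func is e.g. torch.cat or torch.Tensor.__getitem__
        self.ops.append(getattr(func, "__name__", repr(func)))
        # The mode is disabled while this handler runs, so calling func does not recurse.
        return func(*args, **kwargs)

def forward(x, y):
    z = x[:, :2] * 2 + y              # slicing, *, +
    return torch.cat([z, z], dim=1)   # torch.cat

x, y = torch.randn(3, 4), torch.randn(3, 2)
with OpRecorder() as rec:
    out = forward(x, y)

print(rec.ops)
```

The exact names logged for the operators can vary between PyTorch versions, so it is worth printing `rec.ops` once before relying on specific strings.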

I tried monkey patching, but I’m currently stumped because __getitem__ (i.e. slicing) sometimes returns a different shape when patched than when it isn't.
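Without seeing the patch it's hard to say what goes wrong, but one plausible cause is a wrapper that changes the index object before forwarding it. For example, `t[[0, 1]]` (a list, advanced indexing on dim 0) and `t[(0, 1)]` (a tuple, one index per dimension) give different shapes. A minimal sketch that passes the index through untouched and always restores the original method (`patched_getitem` and `getitem_log` are hypothetical names):

```python
import contextlib
import torch

getitem_log = []  # (input shape, output shape) for each indexing call

@contextlib.contextmanager
def patched_getitem():
    original = torch.Tensor.__getitem__

    def logging_getitem(self, index):
        # Forward the index object exactly as received; converting it
        # (e.g. list -> tuple) would change the indexing semantics.
        out = original(self, index)
        getitem_log.append((tuple(self.shape), tuple(out.shape)))
        return out

    torch.Tensor.__getitem__ = logging_getitem
    try:
        yield getitem_log
    finally:
        torch.Tensor.__getitem__ = original  # restore even if the body raises

t = torch.arange(12).reshape(3, 4)
with patched_getitem():
    a = t[[0, 1]]   # advanced indexing on dim 0 -> shape (2, 4)
    b = t[0, 1]     # one index per dim -> shape ()
    c = t[:, 1:3]   # slicing -> shape (3, 2)
```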

I’ve looked at ONNX, and it is all there, but in far more detail than I want, which makes it hard to see what the actual operations are.

TensorBoard's SummaryWriter.add_graph appears to just take the same ONNX-style graph (which I think comes from torch.jit anyway) and make a pretty picture out of it, which is also not what I’m after.
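A possibly lighter-weight graph than ONNX is what `torch.fx.symbolic_trace` produces. This is a swapped-in suggestion, not something mentioned above. It records Python-level operations, so slicing shows up as `operator.getitem`, `*` and `+` as `operator.mul`/`operator.add`, and `torch.cat` as itself, while submodules stay as single `call_module` nodes. A sketch with a hypothetical `Net`; note that symbolic tracing may fail on models with data-dependent control flow:

```python
import operator
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        h = self.fc(x)
        z = h[:, :2] * 2 + 1
        return torch.cat([z, z], dim=1)

gm = symbolic_trace(Net())
for node in gm.graph.nodes:
    # node.op is one of: placeholder, get_attr, call_function,
    # call_method, call_module, output
    print(node.op, node.target)
```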

I guess one option: does every functional operation have an nn.Module sibling? If so, I could just rewrite the networks to use only nn.Module-derived classes. I’m guessing that is not the case.
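For ops that lack a module counterpart (there is no built-in `nn` module for `+` or `torch.cat`, for instance), the rewrite could use thin wrapper modules so the existing forward hooks see them. A sketch under that assumption; `Add`, `Cat`, `Slice` and `Net` are hypothetical classes, not part of torch.nn:

```python
import torch
import torch.nn as nn

# Hypothetical wrappers: each turns a functional op into an nn.Module
# so that register_forward_hook sees it like any other layer.
class Add(nn.Module):
    def forward(self, a, b):
        return a + b

class Cat(nn.Module):
    def __init__(self, dim=0):
        super().__init__()
        self.dim = dim

    def forward(self, *tensors):
        return torch.cat(tensors, dim=self.dim)

class Slice(nn.Module):
    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        return x[self.index]

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.take_first_two = Slice((slice(None), slice(0, 2)))  # x[:, :2]
        self.add = Add()
        self.cat = Cat(dim=1)

    def forward(self, x, y):
        z = self.add(self.take_first_two(x), y)
        return self.cat(z, z)

seen = []
net = Net()
for name, m in net.named_children():
    # default arg binds the current name; hooks receive (module, inputs, output)
    m.register_forward_hook(lambda mod, inp, out, name=name: seen.append((name, tuple(out.shape))))

net(torch.randn(3, 4), torch.randn(3, 2))
```

The cost is that every network has to be rewritten to call these wrappers instead of the plain operators.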

Looking for ideas.