Can I use hooks to check how tensors move between layers

I’ve tried my hand at hooks and understand the basic idea of how to add them (I’m using forward hooks) to each layer/module. This lets me, e.g., print the shapes of the tensors coming into and going out of a layer.
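For context, a minimal version of that shape-printing setup could look like the sketch below (the model and layer sizes are made up for the example):

```python
import torch
import torch.nn as nn

records = []

def shape_hook(module, inputs, output):
    # collect and print the shapes of tensors entering and leaving the module
    in_shapes = [tuple(t.shape) for t in inputs if isinstance(t, torch.Tensor)]
    records.append((module.__class__.__name__, in_shapes, tuple(output.shape)))
    print(records[-1])

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
for m in model.children():  # hook the leaf modules, not the container itself
    m.register_forward_hook(shape_hook)

model(torch.randn(1, 10))
```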

Is there also a way to tell from which module a subsequent module got its input?

I’m not sure there would be any proper way to determine the preceding module using hooks. Even checking the .grad_fn wouldn’t be sufficient, since some layers report the underlying low-level operation rather than their layer name:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(1, 3, 24, 24)
out = bn(x)
print(out.grad_fn)
# <NativeBatchNormBackward0 object at 0x7fa1d8612730>

lin = nn.Linear(10, 10)
x = torch.randn(1, 10)
out = lin(x)
print(out.grad_fn)
# <AddmmBackward0 object at 0x7fa1d86d6a30>

lin = nn.Linear(10, 10, bias=False)
x = torch.randn(1, 10)
out = lin(x)
print(out.grad_fn)
# <MmBackward0 object at 0x7fa1d8667610>
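That said, one hook-only workaround (not from this thread, just a sketch) is to do the bookkeeping yourself: record which module produced each output tensor and look up the inputs of later modules in that map. The caveat is that this matches tensors by their Python id, which can be reused once a tensor is freed, so it is only reliable while the tracked tensors stay alive during the forward pass:

```python
import torch
import torch.nn as nn

producer = {}  # id(tensor) -> name of the module that produced it
edges = []     # (source module, target module) pairs

def make_hook(name):
    def hook(module, inputs, output):
        for t in inputs:
            if isinstance(t, torch.Tensor):
                # a tensor not seen before must be a model input
                edges.append((producer.get(id(t), "model input"), name))
        if isinstance(output, torch.Tensor):
            producer[id(output)] = name
    return hook

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.BatchNorm2d(8))
for name, m in model.named_children():
    m.register_forward_hook(make_hook(f"{name}:{m.__class__.__name__}"))

model(torch.randn(1, 3, 24, 24))
print(edges)
```

For this simple sequential model the recorded edges follow the module order; for models with branching, in-place ops, or tensors that are freed mid-forward, the id-based matching would need more care.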

Your use case comes close to tracing the actual operations, and @albanD shared a few code snippets using the __torch_dispatch__ mechanism here.
This tracing-tensor code would still record the actual ops, but it could perhaps be adapted to record module types instead.
Let me know if this approach works for you or if you depend on hooks.
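A rough sketch of the dispatch-based tracing idea, using a TorchDispatchMode rather than a tensor subclass (simplified relative to the linked snippets, so treat the names here as illustrative):

```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    """Log every ATen op executed while the mode is active."""
    def __init__(self):
        super().__init__()
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))  # e.g. an 'aten.addmm....' overload for Linear
        return func(*args, **(kwargs or {}))

lin = nn.Linear(10, 10)
with OpLogger() as logger:
    lin(torch.randn(1, 10))
print(logger.ops)
```

Since the mode sits below autograd, composite ops are already decomposed here, which is exactly the “lower level” view mentioned above; mapping those ops back to module types would need extra bookkeeping, e.g. via module hooks that mark which module is currently running.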


@ptrblck , thanks! I thought this would be relatively easy to do since I stumbled upon various visualization methods. I assume they need to know which module feeds into which to draw the edges between modules.

Of course, I don’t even know whether those visualization packages use hooks or something else entirely. Hooks were just my first guess. I’ll need to check the links you posted.

I believe libraries visualizing the computation graph would also trace or script the model internally (though I might be wrong).
@tom created this great notebook a while ago which does exactly this.
Based on the code, the model is indeed traced, and I assume this is a requirement.
