Tensor Tracking

You can register a forward hook on any nn.Module to see which tensors go in and which tensors come out. Here is an example that registers a hook on every module:

hooks = {}
for name, module in model.named_modules():
    # register_forward_hook returns a handle; call hooks[name].remove() to detach it later
    hooks[name] = module.register_forward_hook(hook_fn)

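The hook_fn that gets passed can look something like this. Keep in mind that input is always a tuple of the positional arguments given to forward, and output is not necessarily a single tensor: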

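import torch

def hook_fn(module, input, output):
    # input is a tuple of forward()'s positional arguments, so check each element
    print([x.shape for x in input if torch.is_tensor(x)])
    # output is whatever forward() returned; only print it if it is a single tensor
    if torch.is_tensor(output):
        print(output.shape)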
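
If you would rather collect the shapes than print them, something along these lines should work. This is only a sketch: track_shapes and make_hook are made-up helper names (not part of PyTorch), and the small nn.Sequential model is just a stand-in for your own model. It records the shapes per module name during one forward pass and removes the hooks again afterwards.

```python
import torch
from torch import nn


def track_shapes(model, example_input):
    """Run one forward pass and return {module_name: (input_shapes, output_shape)}.

    Hypothetical helper for illustration; not a PyTorch API.
    """
    shapes = {}
    handles = []

    def make_hook(name):
        # Close over the module name so each hook knows where to store its result.
        def hook_fn(module, inputs, output):
            # inputs is a tuple of the positional arguments passed to forward()
            in_shapes = [tuple(x.shape) for x in inputs if torch.is_tensor(x)]
            # output may be a tensor, a tuple, a dict, ...; only handle the tensor case
            out_shape = tuple(output.shape) if torch.is_tensor(output) else None
            shapes[name] = (in_shapes, out_shape)
        return hook_fn

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))

    try:
        with torch.no_grad():
            model(example_input)
    finally:
        # Always detach the hooks so they do not fire on later forward passes.
        for handle in handles:
            handle.remove()
    return shapes


# Stand-in model for the example
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
shapes = track_shapes(model, torch.randn(3, 4))
for name, (in_shapes, out_shape) in shapes.items():
    print(repr(name), in_shapes, out_shape)
```

Note that named_modules() also yields the root module under the empty name "", so the dictionary contains an entry for the whole model as well as for each submodule.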