You can register a forward hook on any `nn.Module` to see which tensors go into it and which tensors come out. Here is an example that registers a hook on every module in a model:
```python
# Keep the returned handles so each hook can be removed later with handle.remove()
hooks = {}
for name, module in model.named_modules():
    hooks[name] = module.register_forward_hook(hook_fn)
```
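Two details of the loop above are easy to miss: `named_modules()` also yields the model itself (with the empty name `""`), and each `register_forward_hook` call returns a handle whose `remove()` detaches the hook again. The sketch below shows both, assuming a hypothetical toy `nn.Sequential` model and a hypothetical `count_hook` that only records which modules ran.

```python
import torch
import torch.nn as nn

# Hypothetical toy model, used only for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

calls = []

def count_hook(module, inputs, output):
    # Record the class name of every module whose forward() just finished
    calls.append(type(module).__name__)

# named_modules() yields the root module itself (name "") plus every submodule
handles = {name: m.register_forward_hook(count_hook) for name, m in model.named_modules()}
print(list(handles))  # ['', '0', '1', '2']

model(torch.randn(1, 4))
fired = list(calls)
# Forward hooks run after forward() returns, so the container fires last
print(fired)  # ['Linear', 'ReLU', 'Linear', 'Sequential']

# Each handle's remove() detaches its hook; later forward passes are not recorded
for h in handles.values():
    h.remove()
calls.clear()
model(torch.randn(1, 4))
print(calls)  # []
```

If you only care about the leaf layers, filter out the root (`if name`) or containers when registering.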
The `hook_fn` that gets passed in receives the module, its inputs, and its output, and can look something like this:
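```python
import torch

def hook_fn(module, input, output):
    # `input` is a tuple of the positional args passed to forward(), not a tensor
    print(type(module).__name__, [t.shape for t in input if torch.is_tensor(t)])
    print(output.shape)  # assumes the module returns a single tensor
```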
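Printing works for a quick look, but it can be handier to collect the shapes keyed by module name. Below is a minimal sketch, assuming a hypothetical toy `nn.Sequential` model and a hypothetical `make_hook` helper; the closure gives each hook its module's name, and the output handling also covers modules that return a tuple or list of tensors.

```python
import torch
import torch.nn as nn

# Hypothetical toy model; any nn.Module can be hooked the same way
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

shapes = {}

def make_hook(name):
    # Closure so each hook knows which module name it belongs to
    def hook_fn(module, inputs, output):
        # inputs is a tuple of positional args to forward(); keep only tensors
        in_shapes = [tuple(t.shape) for t in inputs if torch.is_tensor(t)]
        # output may be a single tensor or a tuple/list of tensors
        outs = output if isinstance(output, (tuple, list)) else (output,)
        out_shapes = [tuple(t.shape) for t in outs if torch.is_tensor(t)]
        shapes[name] = (in_shapes, out_shapes)
    return hook_fn

# Skip the root module, whose name is the empty string
handles = [m.register_forward_hook(make_hook(name))
           for name, m in model.named_modules() if name]

with torch.no_grad():
    model(torch.randn(3, 4))

# Detach the hooks once the shapes have been collected
for h in handles:
    h.remove()

for name, (ins, outs) in shapes.items():
    print(name, ins, "->", outs)
# 0 [(3, 4)] -> [(3, 8)]
# 1 [(3, 8)] -> [(3, 8)]
# 2 [(3, 8)] -> [(3, 2)]
```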