Tensor Tracking

I am using torch._dynamo and torch.compile to generate an FX graph while training a ResNet-50 model. I have attached a screenshot of the FX graph. The graph shows the torch operations, the tensors each operation takes as arguments, and the tensor each operation outputs. I have used red boxes to mark a few tensors. How do I track these tensors while training the model in PyTorch?

I think hooks may be what you’re looking for.

When I train, I do not know which tensors will be created or where in the PyTorch code they are created. How do I attach a hook to a tensor? Can you guide me?

You can register a hook on any nn.Module to see which tensors go into it and which tensors come out of it. Here is an example that registers a forward hook on every module in the model:

hooks = {}  # keep the handles so each hook can be removed later with .remove()
for name, module in model.named_modules():
    hooks[name] = module.register_forward_hook(hook_fn)

The hook_fn that gets passed can look something like this:

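def hook_fn(module, input, output):
    # input is a tuple of the positional arguments passed to forward()
    print(type(module).__name__, [x.shape for x in input])
    print(output.shape)  # assumes the module returns a single tensor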
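Module hooks show what goes in and out of each layer. To hook an individual tensor, torch.Tensor.register_hook registers a function that PyTorch calls during the backward pass with the gradient of the loss with respect to that tensor. The sketch below combines the two, assuming you want gradient statistics for each layer's output during training: a forward hook on each child module attaches a tensor hook to that module's output. The small nn.Sequential model and the names grad_norms and make_forward_hook are hypothetical stand-ins for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical small model standing in for ResNet-50
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
grad_norms = {}  # module name -> norm of d(loss)/d(module output)


def make_forward_hook(name):
    def forward_hook(module, inputs, output):
        # Tensor hooks can only be registered on tensors that require grad
        if output.requires_grad:
            def grad_hook(grad):
                # Called during backward with the gradient w.r.t. this output
                grad_norms[name] = grad.norm().item()
            output.register_hook(grad_hook)
    return forward_hook


handles = [m.register_forward_hook(make_forward_hook(n))
           for n, m in model.named_children()]

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()
print(grad_norms)

# Remove the module hooks once you are done tracking
for h in handles:
    h.remove()
```

Because the tensor hook is re-attached on every forward pass, it tracks whichever tensor that layer produces in the current step, so you do not need to know in advance where the tensor is created.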

Has your problem been solved? I also want to know how to track tensors in torch.compile mode.

@vdw Is it possible to identify a tensor's role, for example whether it is a weight tensor, a bias tensor, an input tensor, or an activation tensor?
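One heuristic way to do this with hooks, sketched under the assumption that the model uses the usual PyTorch naming where layer parameters are called weight and bias: parameters come from named_parameters(), the model input is whatever you pass in, and the output of each leaf module is treated as an activation. The function classify_tensors and its labels are hypothetical, not a PyTorch API.

```python
import torch
import torch.nn as nn


def param_role(name: str) -> str:
    # Heuristic: built-in layers name their parameters "weight" and "bias"
    leaf = name.rsplit(".", 1)[-1]
    return leaf if leaf in ("weight", "bias") else "parameter"


def classify_tensors(model: nn.Module, x: torch.Tensor) -> dict:
    """Label tensors seen in one forward pass as weight/bias/input/activation."""
    roles = {}
    for name, p in model.named_parameters():
        roles[name] = (param_role(name), tuple(p.shape))
    roles["input"] = ("input", tuple(x.shape))

    def make_hook(mod_name):
        def hook(module, inputs, output):
            if torch.is_tensor(output):
                roles[f"{mod_name}:output"] = ("activation", tuple(output.shape))
        return hook

    # Hook only leaf modules so each activation is recorded once
    handles = [m.register_forward_hook(make_hook(n))
               for n, m in model.named_modules()
               if len(list(m.children())) == 0]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()
    return roles


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
roles = classify_tensors(model, torch.randn(2, 4))
for key, (role, shape) in roles.items():
    print(f"{key:12s} {role:10s} {shape}")
```

Buffers such as BatchNorm's running_mean are not parameters, so they would need model.named_buffers() if you want them labeled too.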

@Gu Wei, we can use hooks with the model, but I am not sure how to track tensors under torch.compile.
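One possible approach for torch.compile, sketched under the assumption of PyTorch 2.x: pass a custom backend callable to torch.compile. Dynamo calls the backend with the captured torch.fx.GraphModule (the kind of FX graph shown in the screenshot) plus example inputs, and the backend returns the function to run. Here the backend runs the graph through a torch.fx.Interpreter subclass that records every tensor each node produces. RecordingInterpreter and make_recording_backend are hypothetical names, and the small model stands in for ResNet-50.

```python
import torch
import torch.nn as nn
from torch.fx import GraphModule, Interpreter


class RecordingInterpreter(Interpreter):
    """Runs an FX graph node by node and logs each tensor it produces."""

    def __init__(self, gm: GraphModule, log: list):
        super().__init__(gm)
        self.log = log

    def run_node(self, n):
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            # n.name matches the node names in the printed graph
            self.log.append((n.name, n.op, tuple(result.shape)))
        return result


def make_recording_backend(log: list):
    def backend(gm: GraphModule, example_inputs):
        # Return a callable that executes the captured graph via the interpreter
        def run(*args):
            return RecordingInterpreter(gm, log).run(*args)
        return run
    return backend


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
log = []
compiled = torch.compile(model, backend=make_recording_backend(log))
x = torch.randn(2, 4)
out = compiled(x)
for name, op, shape in log:
    print(f"{op:14s} {name:30s} {shape}")
```

Since the node names in the log are the graph's node names, they could be matched against the red-boxed tensors. Interpreting the graph node by node is much slower than compiled code, so this is a debugging aid rather than something to leave on for full training; the log also grows on every call, so clear it each step.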