Inspecting gradients of a Tensor's computation graph

Hello,

I am trying to figure out a way to analyze the propagation of gradient through a model’s computation graph in PyTorch. In principle, it seems like this could be a straightforward thing to do given full access to the computation graph, but there currently appears to be no way to do this without digging into PyTorch internals. Thus there are two parts to my question: (a) how close can I come to accomplishing my goals in pure Python, and (b) more importantly, how would I go about modifying PyTorch to suit my needs?

Here is my use case: given a torch.Tensor, say L, that represents the loss of a model on a batch of training data, I would like to (a) recover the entire computation graph of which L is the root, (b) after running L.backward(), access the gradient accumulated at every vertex in the computation graph, and (c) access the forward values stored at each vertex, because I need to re-compute the gradient at certain nodes (matrix multiplications in particular) in order to determine how much of each incoming gradient propagates to the outgoing gradients.

Here is my problem: although I am able to achieve one or two of these goals individually, I have found no way to achieve all three at the same time.

The first thing I tried was to use L.grad_fn and L.grad_fn.next_functions to recover the computation graph. This worked just fine, but there is no way to access the forward or gradient values at each vertex: the vertices in the graph are “backward function” objects, not Tensors. Based on my research so far, it seems like all my problems would be solved if only there were a way to reference the original Tensor expression that corresponds to the output (or inputs) of each backward function.
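For reference, this is roughly how I walk the graph right now (a minimal sketch; the toy tensors below just stand in for a real model and loss):

```python
import torch

def walk(grad_fn, seen=None, depth=0):
    # Depth-first traversal of the backward graph rooted at grad_fn.
    if seen is None:
        seen = set()
    if grad_fn is None or grad_fn in seen:
        return
    seen.add(grad_fn)
    print("  " * depth + type(grad_fn).__name__)
    for next_fn, _ in grad_fn.next_functions:
        walk(next_fn, seen, depth + 1)

# Toy stand-ins for a real model and loss.
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
L = (x @ w).sum()
walk(L.grad_fn)   # prints SumBackward, MmBackward and the AccumulateGrad leaves (exact names vary by version)
```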

Anyway, each backward function object does give you a way to compute the gradient of the loss with respect to that vertex’s inputs, so I went a step further and re-implemented backpropagation in Python so that I could achieve (a) and (b) at the same time. However, I am still unable to achieve (c).
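(In case it helps anyone: the re-implementation relies on the fact that the backward function objects are callable from Python. A toy sketch of the idea, with placeholder shapes:)

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
y = x @ w

# A backward node maps the gradient w.r.t. its output to the gradients
# w.r.t. its inputs, ordered like .next_functions.
grad_x, grad_w = y.grad_fn(torch.ones_like(y))
print(grad_x.shape, grad_w.shape)   # (4, 3) and (3, 2)
```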

I am aware that PyTorch gives you an option to preserve the computation graph and its gradients after running .backward(), but those gradients are only accessible via the Tensor objects that make up the computation graph, not via the backward function objects. So, without a way to get a representation of the computation graph as Tensor objects, I’m pretty much stuck. The reason it is important to me that this works given a single root expression L is that the code that constructs the Tensor objects making up the model is usually buried inside torch.nn.Module objects. My end goal is to be able to take a third-party implementation of a model and analyze its gradients without needing to modify any code.
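(To be concrete, the option I mean is backward(retain_graph=True) together with Tensor.retain_grad(), which only helps if you already hold a reference to the intermediate Tensor, e.g.:)

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)

h = x @ w          # intermediate, non-leaf tensor
h.retain_grad()    # only possible because we have a reference to h
L = h.sum()
L.backward(retain_graph=True)

print(h.grad)      # populated, but only reachable through the Tensor h itself
```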

I’m open to hacking my own version of PyTorch that would support this. How would I go about adding an .output_tensor property to each backward function object that would allow me to access the forward values and gradients from the computation graph?

Thanks,
Brian

Any ideas?

I’ve been looking at this to get me started:

Thanks!

Hi,

The problem here is that most Functions are actually implemented in C++ and called automatically. There is no way to get some code executed every time a Function’s forward is run (or its backward is created).

In your case, you can also say that you only support nn.Module-based models. This way, you can look into the model that is given to you and use .register_forward_hook to get both the inputs and outputs of each Module.
From there, you can .register_hook on the tensors whose values and gradients you want.
Doing that for all relevant modules like Conv and Linear will give you something similar to what you want, no?
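For instance, something along these lines (just a rough sketch; the names and the toy model are placeholders):

```python
import torch
import torch.nn as nn

captured = {}  # module name -> {"output": forward value, "grad": grad w.r.t. output}

def make_hook(name):
    def forward_hook(module, inputs, output):
        captured[name] = {"output": output.detach(), "grad": None}
        # Register a hook on the output tensor to catch its gradient during backward.
        def grad_hook(grad):
            captured[name]["grad"] = grad.detach()
        output.register_hook(grad_hook)
    return forward_hook

model = nn.Sequential(nn.Linear(3, 5), nn.ReLU(), nn.Linear(5, 1))
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        module.register_forward_hook(make_hook(name))

loss = model(torch.randn(4, 3)).sum()
loss.backward()
print({name: d["grad"].shape for name, d in captured.items()})
```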

Hi @albanD,

Thanks for the idea! I didn’t realize you could access the input and output tensors that way. So your suggestion is to loop over all of the .modules() and register a hook for each of them. That would give me access to at least some of the tensors buried in the computation graph, and from there it would be easy to connect those tensors to the nodes in the .grad_fn computation graph I built, using a dict (see the sketch below). Restricting support to nn.Modules is certainly an option in the near term. Might be worth looking into.
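Something like this, I think (a quick sketch building on your suggestion; keying the dict by the output’s .grad_fn should let me attach each captured Tensor to the matching node in the graph I already built):

```python
import torch
import torch.nn as nn

# Map each backward node to the Tensor produced by the corresponding module.
node_to_tensor = {}

def forward_hook(module, inputs, output):
    if isinstance(output, torch.Tensor) and output.grad_fn is not None:
        node_to_tensor[output.grad_fn] = output

model = nn.Sequential(nn.Linear(3, 5), nn.ReLU(), nn.Linear(5, 1))
for module in model.modules():
    module.register_forward_hook(forward_hook)

L = model(torch.randn(4, 3)).sum()

# While walking L.grad_fn / .next_functions, any node that happens to be a
# module output can now be resolved back to its Tensor.
node = L.grad_fn.next_functions[0][0]   # node that produced the model output
print(node in node_to_tensor)           # True: it is the last Linear's output
```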

I’m fairly certain that I will inevitably need to modify the PyTorch source in order to do exactly what I want. My goal is to find, for each parameter, the path in the computation graph through which the most gradient flows. The tricky part is that, in order to do this correctly, the gradient needs to be known for every part of the computation graph, all the way from the loss node down to the parameter nodes. The other tricky part is that every tensor operation (matrix multiplications in particular) encapsulates the equivalent of multiple scalar operations in the computation graph. So looking purely at the gradient tensors flowing in and out of each tensor operation is not enough: you need knowledge specific to each tensor operation to figure it out. That’s why I’m interested in accessing the forward value of each tensor operation: so I can recompute the scalar terms of the gradient and pick the ones with the highest magnitude.
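To make that concrete, for a single multiplication y = x @ w the gradient reaching x is a sum of per-element products, and those are the scalar terms I want to rank by magnitude. A toy sketch of the idea (not how I’d do it at scale):

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
y = x @ w
grad_y = torch.ones_like(y)   # pretend this is the gradient arriving at y

# dL/dx[i, j] = sum_k grad_y[i, k] * w[j, k]; keep the individual terms.
terms = torch.einsum('ik,jk->ijk', grad_y, w)        # shape (4, 3, 2)
assert torch.allclose(terms.sum(dim=-1), grad_y @ w.t())

# For every element of x, the index k whose term contributes the most.
strongest_path = terms.abs().argmax(dim=-1)          # shape (4, 3)
```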

Is there someone familiar with the source code who would know how to add the .output_tensor property I mentioned?

Thanks!

All Functions have a .metadata field, which is a dictionary you can use to save anything you want associated with them.
The problem is that the Functions you encounter in the computation graph are the ones implementing the backward. These are created during the forward pass by the corresponding forward Functions and are given everything they need.
The thing is that these forward Functions are mostly generated automatically from this derivatives.yaml file. The other ones are written by hand in the files in this folder. And the Python-created Functions are handled by the python_function.cpp file and some Python class magic from FunctionMeta.
You would need to modify the metadata field of the newly created backward function in each of these three places. The last one should be fairly easy; the second one is going to be annoying, as you will need to change the code of every single Function declared in that folder; and for the first one, you will need to dive into the automatically generated code and handle all the weird cases for all Functions (the ones that do in-place ops and don’t have an output, or the ones that have outputs called out or self…).
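Note that the metadata dict is already reachable from Python, so for quick experiments you can put things into it after the graph has been built (rough sketch; storing the output in its own node like this creates a reference cycle):

```python
import torch

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
y = x @ w

y.grad_fn.metadata["output"] = y   # beware: Tensor -> node -> metadata -> Tensor cycle
print(y.grad_fn.metadata)
```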


Thanks, these pointers are extremely helpful! I’ll take a look and let you know how it goes.

Brian


I think I see what you mean by needing to modify each function definition. For now, I’m mostly interested in re-computing the gradient for affine transformations (the addmm(self, mat1, mat2) operation).

Ignoring the alpha and beta parameters for a moment, I noticed that all of the information you need to re-compute the gradient is stored in SavedVariables. Is there a way to access the SavedVariables through the Python API? I ask because the Function object has a saved_variables/saved_tensors property defined in python_function.cpp. However, this property does not seem to exist whenever I try to access it on a Backward object, which is a Function object as I understand it.
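For context, the recomputation I have in mind is just the standard backward of addmm (ignoring alpha and beta); addmm_grads below is my own helper name, and the snippet only sanity-checks the formulas against autograd on toy tensors:

```python
import torch

def addmm_grads(grad_output, mat1, mat2):
    # Gradients of out = input + mat1 @ mat2 (alpha = beta = 1).
    grad_input = grad_output            # identity w.r.t. input (no broadcasting here)
    grad_mat1 = grad_output @ mat2.t()
    grad_mat2 = mat1.t() @ grad_output
    return grad_input, grad_mat1, grad_mat2

inp = torch.randn(4, 2, requires_grad=True)
m1 = torch.randn(4, 3, requires_grad=True)
m2 = torch.randn(3, 2, requires_grad=True)
out = torch.addmm(inp, m1, m2)
out.backward(torch.ones_like(out))

gi, g1, g2 = addmm_grads(torch.ones_like(out), m1, m2)
assert torch.allclose(gi, inp.grad)
assert torch.allclose(g1, m1.grad)
assert torch.allclose(g2, m2.grad)
```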

These properties are set when you use save_for_backward() during the forward pass.
save_for_backward() should only be used for input/output Tensors; other things should be saved directly on the context, e.g. ctx.my_tensor = my_tensor.
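For a custom Python Function, that looks roughly like this (a minimal sketch; MyMul is just a made-up example):

```python
import torch
from torch.autograd import Function

class MyMul(Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)   # input/output Tensors go here
        ctx.scale = 1.0               # anything else is stored directly on ctx
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        return grad_output * b * ctx.scale, grad_output * a * ctx.scale

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
out = MyMul.apply(a, b)
print(out.grad_fn.saved_tensors)   # the pair saved with save_for_backward
```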

A SavedVariable is any Tensor that is saved somewhere in the graph.
But here, since every C++-implemented Function is different, there is no common API to expose what they saved. Such an API would be tricky, as Functions save very different things, so given an arbitrary Function you would have no guarantee about what you would find there. It would also create even more headaches around reference cycles between Python objects :wink:
