Hello,
I am trying to figure out a way to analyze how gradients propagate through a model’s computation graph in PyTorch. In principle this should be straightforward given full access to the computation graph, but there currently appears to be no way to do it without digging into PyTorch internals. So my question has two parts: (a) how close can I come to accomplishing my goals in pure Python, and (b) more importantly, how would I go about modifying PyTorch to suit my needs?
Here is my use case: given a `torch.Tensor`, say `L`, that represents the loss of a model on a batch of training data, I would like to (a) recover the entire computation graph of which `L` is the root, (b) after running `L.backward()`, access the gradient accumulated at every vertex in the computation graph, and (c) access the forward values stored at each vertex, because I need to re-compute the gradient at certain nodes (matrix multiplications in particular) to determine how much each incoming gradient propagated to the outgoing gradients.
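For a single matrix multiplication, the re-computation in (c) is the standard chain-rule identity. Writing the forward pass as $Y = XW$ and $G = \partial L / \partial Y$ for the gradient arriving at that vertex (notation introduced here for illustration):

$$
\frac{\partial L}{\partial X} = G\,W^{\top}, \qquad \frac{\partial L}{\partial W} = X^{\top} G
$$

so both the incoming gradient $G$ and the forward values $X$ and $W$ are needed at that vertex.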
Here is my problem: although I can achieve one or two of these goals individually, I have found no way to achieve all three at the same time.
The first thing I tried was to use `L.grad_fn` and `L.grad_fn.next_functions` to recover the computation graph. This worked fine, but it gives no access to the forward or gradient values at each vertex: the vertices in the graph are “backward function” objects, not `Tensor`s. In fact, based on my research so far, it seems like all my problems would be solved if only there were a way to reference the original `Tensor` expression that corresponds to the output (or inputs) of each backward function.
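A minimal sketch of that traversal, for reference (the helper name `walk_graph` and the toy model are illustrative, not from any library). It visits every node reachable from `L.grad_fn`, and what it yields are backward-function objects rather than `Tensor`s, which is exactly the limitation described above.

```python
from collections import deque

import torch


def walk_graph(loss):
    """Breadth-first walk over the autograd graph rooted at loss.grad_fn.

    Returns (node, depth) pairs. Each node is a backward-function object
    (e.g. MmBackward0, AccumulateGrad), not a Tensor.
    """
    seen, order = set(), []
    frontier = deque([(loss.grad_fn, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if node is None or node in seen:
            continue  # None marks an input that does not require grad
        seen.add(node)
        order.append((node, depth))
        # next_functions is a tuple of (node, output_index) pairs
        for child, _ in node.next_functions:
            frontier.append((child, depth + 1))
    return order


# Usage: a tiny model with a single weight matrix
w = torch.randn(3, 2, requires_grad=True)
x = torch.randn(4, 3)
L = (x @ w).relu().sum()
for node, depth in walk_graph(L):
    print("  " * depth + type(node).__name__)
```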
Anyway, each backward function object does give you a way to compute the gradient of the loss with respect to that vertex’s inputs, so I went a step further and re-implemented backpropagation in Python so I could achieve (a) and (b) at the same time. However, I am still unable to achieve (c).
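A Python re-implementation of backpropagation along those lines might look roughly like this sketch. It assumes every node has a single output (true for the toy chain here, not for ops such as `split`), relies on the fact that calling a backward node applies its backward formula to the gradient passed in, and `manual_backward` is a hypothetical name. The recursive ordering step may also hit Python's recursion limit on very deep graphs.

```python
import torch


def manual_backward(loss):
    """Reverse-mode pass over loss.grad_fn written in plain Python.

    Sketch only: assumes single-output nodes. Returns a dict mapping each
    AccumulateGrad node (whose .variable is the leaf tensor) to its gradient.
    """
    root = loss.grad_fn
    order, seen = [], set()

    def visit(node):  # DFS post-order; reversed, it is a topological order
        if node is None or node in seen:
            return
        seen.add(node)
        for child, _ in node.next_functions:
            visit(child)
        order.append(node)

    visit(root)
    pending = {root: torch.ones_like(loss)}  # dL/dL = 1
    leaf_grads = {}
    for node in reversed(order):  # every node runs after all of its consumers
        grad = pending.pop(node, None)
        if grad is None:
            continue
        if type(node).__name__ == "AccumulateGrad":
            leaf_grads[node] = grad
            continue
        grads_in = node(grad)  # single result is returned unwrapped
        if not isinstance(grads_in, tuple):
            grads_in = (grads_in,)
        for (child, _), g in zip(node.next_functions, grads_in):
            if child is not None and g is not None:
                # sum contributions when a node feeds several consumers
                pending[child] = pending[child] + g if child in pending else g
    return leaf_grads


w = torch.randn(3, 2, requires_grad=True)
x = torch.randn(4, 3)
L = (x @ w).relu().sum()
(expected,) = torch.autograd.grad(L, w, retain_graph=True)
(got,) = manual_backward(L).values()
print(torch.allclose(got, expected))
```

Comparing against `torch.autograd.grad` is only a sanity check; the `pending` dictionary is where per-vertex gradients (goal (b)) become visible.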
I am aware that PyTorch gives you the option to keep the computation graph (`retain_graph=True`) and intermediate gradients (`Tensor.retain_grad()`) after running `.backward()`, but those gradients are only accessible via the `Tensor` objects that make up the computation graph, not via the backward function objects. So, without a way to get a representation of the computation graph as `Tensor` objects, I’m pretty much stuck. The reason it is important to me that this works given a single root expression `L` is that the code constructing the `Tensor` objects that make up a model is usually buried inside `torch.nn.Module` objects. My end goal is to take a third-party implementation of a model and analyze its gradients without needing to modify any of its code.
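For comparison, the `Tensor`-side mechanism looks like this in a minimal sketch: `retain_grad()` keeps the gradient of a non-leaf tensor and `retain_graph=True` keeps the graph alive, but both need a Python handle on the intermediate tensor (`h` here), which is exactly what is hard to get when the tensors are created inside a third-party `nn.Module`.

```python
import torch

w = torch.randn(3, 2, requires_grad=True)
x = torch.randn(4, 3)
h = x @ w          # intermediate (non-leaf) tensor we happen to hold
h.retain_grad()    # ask autograd to keep dL/dh in h.grad
L = h.relu().sum()
L.backward(retain_graph=True)  # keep saved tensors so backward can run again

print(h.grad.shape)  # gradient at the matmul output vertex
print(w.grad.shape)  # leaf gradients are kept by default
print(h.grad_fn)     # the backward node that produced h
```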
I’m open to hacking my own version of PyTorch to support this. How would I go about adding an `.output_tensor` property to each backward function object that would let me access the forward values and gradients from the computation graph?
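Short of patching PyTorch, a pure-Python approximation may cover part of (a) through (c). This is a sketch, assuming a PyTorch version (roughly 1.10 or later) where backward nodes support `register_hook` and expose the forward tensors they saved as underscore-prefixed `_saved_*` attributes (e.g. `_saved_self` for the first argument of `mm`). That is not the hypothetical `.output_tensor`: a node only keeps the tensors its backward formula needs, the attributes are private and version-dependent, and `record_node_grads` is a made-up helper name.

```python
import torch


def record_node_grads(loss):
    """Hook every backward node reachable from loss (hypothetical helper).

    After backward, records[node]["grad_outputs"] holds the gradient arriving
    at that node (dL/d output) and records[node]["grad_inputs"] the gradients
    it passed on to its inputs.
    """
    records, handles, seen, stack = {}, [], set(), [loss.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        stack.extend(child for child, _ in node.next_functions)
        if type(node).__name__ == "AccumulateGrad":
            continue  # leaf gradients already land in tensor.grad

        def hook(grad_inputs, grad_outputs, node=node):
            records[node] = {"grad_inputs": grad_inputs,
                             "grad_outputs": grad_outputs}

        handles.append(node.register_hook(hook))
    return records, handles


w = torch.randn(3, 2, requires_grad=True)
x = torch.randn(4, 3)
L = (x @ w).relu().sum()
records, handles = record_node_grads(L)
L.backward(retain_graph=True)  # retained so saved tensors stay readable

mm = next(n for n in records if type(n).__name__ == "MmBackward0")
G = records[mm]["grad_outputs"][0]  # dL/dY at the matmul vertex
X = mm._saved_self                  # forward input saved by mm (private attr)
recomputed = X.T @ G                # dL/dW = X^T G
print(torch.allclose(recomputed, w.grad))
for handle in handles:
    handle.remove()
```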
Thanks,
Brian