Traversing the recorded graph manually

I’d like to explore the computation graph recorded by autograd manually. I’ve been playing around with next_functions, which seems to be relevant to my needs:

import torch

x = torch.ones(4, requires_grad=True)
loss = x.pow(2)
print(loss.grad_fn)                    # <PowBackward0 object at ...>
print(loss.grad_fn.next_functions[0])  # (<AccumulateGrad object at ...>, 0)

But I’m not sure I completely understand what’s going on. Each element of next_functions is a tuple of a function object and a number. I don’t know what those numbers are. Also, I haven’t found a way to access the input tensor x just by accessing attributes or calling methods on loss. How can I do it?


Hi,

  • The numbers correspond to the index of the argument for this input: for a function with forward(ctx, in1, in2), the edge that produced in1 will have index 0 and the edge that produced in2 will have index 1.
  • In the computational graph, inputs do not really exist. It is just a graph of Functions. Each Function is responsible for saving whatever state it will need for its backward.
  • For the particular case of Functions created in Python, you have the saved_tensors field to get the things that were saved during the forward (see the sketch after this list). For Functions created in cpp, there is usually no API to access the saved elements from Python.
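As an illustration, here is a minimal sketch with a hypothetical custom Function (MySquare, not from this thread) that saves its input; on recent PyTorch versions the saved tensor can then be read back through the output's grad_fn:

import torch

class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # state the backward will need
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.ones(4, requires_grad=True)
y = MySquare.apply(x)
print(y.grad_fn.saved_tensors)   # (tensor([1., 1., 1., 1.], requires_grad=True),)
print(y.grad_fn.next_functions)  # ((<AccumulateGrad object at ...>, 0),)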

If I create a CustomFunction, what is shown in output_tensor.grad_fn is CustomFunctionBackward. I couldn’t find a way to access CustomFunction itself from the output tensor; is that possible / useful? That said, using this backward function that appeared out of nowhere, I can access saved_tensors if I saved them in CustomFunction.forward, which is the kind of thing I wanted to achieve. But, as you said, I couldn’t access the tensors of built-in functions, unfortunately.

Is there any documentation about how / why this backward class (CustomFunctionBackward) is created? Also, is it reasonable to think of the first argument ctx in CustomFunction.forward as an instance of CustomFunction? I believe in an old version of pytorch it used to be self, but I’m not sure how it works now. Is there some special context class that is passed at some point during the forward / backward pass?

CustomFunctionBackward is created to represent the state of your Function. An instance of it is created during the .apply() call and is passed as the ctx (though the fact that the context is of this class is not important).
The relevant code is the FunctionMeta metaclass in torch/autograd/function.py and THPFunction_apply() in torch/csrc/autograd/python_function.cpp.
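As a quick check of the above (reusing the hypothetical MySquare from earlier; the exact class name can differ between PyTorch versions), the ctx passed to forward is an instance of the generated Backward class, not of the custom Function itself:

import torch

class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # On recent versions this prints something like "MySquareBackward",
        # i.e. the class generated by FunctionMeta, not MySquare itself.
        print(type(ctx).__name__)
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

y = MySquare.apply(torch.ones(2, requires_grad=True))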

@albanD Hi, I have a question. torch.nn.Parallel.Scatter also has an apply function. Why, after PyTorch calls the Scatter.apply function, does the source code go into Scatter.forward rather than Scatter.backward during the forward pass?

Hi,

When you execute a Function, it runs its “forward” and then optionally creates a node that contains the “backward” version of that Function.
For that newly created Function, its “forward” is actually Scatter.backward.

Hi, I am sorry, I can’t understand. How do you explain that BackwardCFunction implements the apply function like this?


It calls the backward function? I really want to understand it. Is the apply function called in Variable._execution_engine.run_backward?

And I have another question. I know that PyTorch uses DistributedDataParallel for distributed training. I read the source code of Reducer. It has a pending variable to synchronize multiple GPU cards within one node. But how does it synchronize different nodes? For example, suppose there are three nodes and the first node finishes one step of gradient calculation. How does the first node know whether the other nodes have finished their gradient calculation or not?
Thank you very much! These questions have been bothering me for days.

Hi,

The question about the apply is answered here: How does pytorch implements backward process?
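In short (my own simplified sketch of the idea, not the actual PyTorch source): the node recorded for a custom Function is an instance of a generated Backward class whose apply simply forwards to the user-defined backward, which is why calling apply on that node during run_backward ends up in the Function’s backward.

# Hypothetical sketch of the dispatch idea, not real PyTorch internals.
class FakeBackwardNode:
    """Stands in for the CustomFunctionBackward node recorded in the graph."""

    def __init__(self, forward_cls, saved_tensors):
        self._forward_cls = forward_cls   # the user-defined Function class
        self.saved_tensors = saved_tensors

    def apply(self, *grad_outputs):
        # What the engine invokes during the backward pass: it dispatches to
        # the user-defined backward, passing this node itself as ctx.
        return self._forward_cls.backward(self, *grad_outputs)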

I am not very familiar with DistributedDataParallel unfortunately, but it most likely just uses the autograd logic that makes sure that multiple uses of a single Tensor accumulate all the gradients before performing the next step. Basically, the op before the Reducer knows that 3 copies exist and so will wait for 3 completions before executing.
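A toy illustration of that accumulation behaviour (unrelated to DDP itself): when a tensor feeds several ops, autograd sums all of the contributions into its gradient.

import torch

x = torch.ones(2, requires_grad=True)
# x is consumed by three separate ops; autograd accumulates the three
# gradient contributions (1 + 2 + 3) into x.grad.
y = x * 1 + x * 2 + x * 3
y.sum().backward()
print(x.grad)  # tensor([6., 6.])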

Hello,

For visualization purposes, I would really like to be able to access the saved tensors (i.e. access the ctx argument passed to forward and backward of a Function, if my understanding is correct).

Is there really no way (even a little hacky) to access the saved elements in Python for Functions created in cpp (which is the case of most functions)?

Thank you very much,

There is no current way, I’m afraid.
It might be possible to add this, but that would be a lot of work given the code that automatically generates these cpp Functions.

I think the thing you need is variable:

import torch

a = torch.randn(1, requires_grad=True)
b = a * (a + 2)
print(b.grad_fn.next_functions)                       # edges into MulBackward0
print(b.grad_fn.next_functions[1][0].next_functions)  # edges into AddBackward0
print(b.grad_fn.next_functions[0][0].variable)        # the leaf tensor a, via its AccumulateGrad node

This does not work. I got an error saying that the specific function has no attribute called variable.

Hello,

Is there a method to access and manipulate saved tensors from Python?

Hi,

No, there is none.

Also, the function has a variable attribute only if it is an AccumulateGrad node.
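For completeness, here is a minimal sketch (my own helper, not a PyTorch API) that walks next_functions and only reads .variable from nodes that actually expose it, which is a safe way to recover the leaf tensors of a graph:

import torch

def collect_leaves(grad_fn, seen=None):
    # Walk next_functions recursively and collect the tensors held by
    # AccumulateGrad nodes (the only nodes with a .variable attribute).
    if seen is None:
        seen = set()
    leaves = []
    if grad_fn is None or grad_fn in seen:
        return leaves
    seen.add(grad_fn)
    if hasattr(grad_fn, "variable"):
        leaves.append(grad_fn.variable)
    for next_fn, _ in grad_fn.next_functions:
        leaves.extend(collect_leaves(next_fn, seen))
    return leaves

a = torch.randn(1, requires_grad=True)
b = a * (a + 2)
print(collect_leaves(b.grad_fn))  # contains the leaf tensor a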

Hi,
Thanks for your reply. I have posted a detailed question in this post; please can you look at that discussion and tell me if there is any way to accomplish the task?

Thank you.