Help needed in understanding how autograd works

This is just a super basic question that I have not been able to successfully google an answer to.

I’m trying to understand how the torch.Tensor object stores information on its position in the graph. Since autograd can calculate gradients based on operational relations between torch.Tensor objects, information about these relations must be stored somewhere. Is it an attribute in the Tensor objects that I cannot find, or is there somewhere else I should look?

Hi,

The Tensors themselves are not linked to the internal parts of the graph.
Only the output of an operation has a .grad_fn attribute, which points to the backward of the Function that created that Tensor.
Then there is a full graph of backward Functions that represents how the backward pass should be computed, where each Function points to the next one(s) to call when doing the backward pass. These Functions can, depending on their needs, save some Tensors, but the Tensors never hold a reference to all the Functions that use them.
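For example (a minimal sketch; the exact class name you see printed depends on the op and your PyTorch version):

import torch

x = torch.randn(3, requires_grad=True)
z = (2 * x).sum()

print(x.grad_fn)                 # None: a leaf Tensor was not created by a Function
print(z.grad_fn)                 # e.g. <SumBackward0 ...>: the backward of the op that created z
print(z.grad_fn.next_functions)  # edges to the Function(s) to call next during the backward pass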

I hope this is clear, don’t hesitate to ask more questions if it is not.


Thank you for the reply.

It certainly made it clearer, but I still have questions:

I am doing my best to find answers by looking through the documentation, the source code, and by inspecting objects in a Jupyter notebook, but I still cannot find the answers I'm looking for. Forgive my incompetence; I have previously mostly worked with C# and am apparently finding it hard to navigate Python documentation and source code.

Let’s take a simple example:

import torch
x = torch.randn(4, 4)
y = 2 * x
z = y.sum()
z.backward()

Calling z.backward() has resulted in x.grad being updated. Ultimately, my question is how.

You say that z.grad_fn contains something leading back to x? From what I can see, type(z.grad_fn) is SumBackward0, which is not a type/class that I have been able to find source or documentation for. Perhaps it is a subclass of the Function class? Basically, I am trying to find some attribute of z which points back to x. If z.grad_fn points back to x, then my question is how/where? Is it possible to inspect z.grad_fn to gain full information about the graph?

Since I am not able to find answers to my own questions, it seems that I must be using the documentation wrong. For instance, I find it hard to get information on grad_fn. If you have any sense of how I might be going about this the wrong way, any guidance is appreciated.


Hi,

First, this part is not really documented, more or less on purpose, as the backend is still moving quite a lot. Even though we make sure the Python/C++ user-facing interfaces do not change, there are still changes in the backend, so what I describe below is the current state and may not be true forever!

In your example, I guess you are missing a requires_grad=True when creating x; otherwise, no gradient will be computed.
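That is, the creation of x would become something like:

x = torch.randn(4, 4, requires_grad=True)  # mark x as a leaf that should collect gradients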

Then in your example:

  • SB = z.grad_fn is indeed an instance of a subclass of Function; it corresponds to the backward method that should be called to perform the backward of the sum op that you did.
  • SB.next_functions is a tuple that contains information about what should be called next. Each entry is itself a tuple with two elements: the next backward Function to call, and which output of that Function this input corresponds to. In your case, it should contain a MulBackward0 Function for the input and link to its 0th output. The input here is y, which was created by the multiplication-by-2 op; that op had a single output (y), so it is output 0.
  • MB = SB.next_functions[0][0] gives us the MulBackward0 Function.
  • MB.next_functions will contain two things (note that because of the wrapping, the order of arguments here does not match what you wrote in Python: x is the first input argument and 2 the second one):
    • The first argument was created by an AccumulateGrad op and is its 0th output.
    • The second argument has None as its backward Function, since it is the number 2 and no gradients need to flow back to it.
  • AG = MB.next_functions[0][0] gets the AccumulateGrad Function.
  • AG.variable is x evaluates to True. This particular subclass of Function has a .variable attribute that contains the Tensor into which the gradients need to be accumulated, in this case the Tensor x. So when calling backward, it will accumulate into the .grad field of the x Tensor (see the sketch after this list).
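Putting that walk into code (a small sketch; the exact node names and the precise contents of next_functions can vary between PyTorch versions):

import torch

x = torch.randn(4, 4, requires_grad=True)
y = 2 * x
z = y.sum()

SB = z.grad_fn                   # SumBackward0
MB = SB.next_functions[0][0]     # MulBackward0
AG = MB.next_functions[0][0]     # AccumulateGrad for x

print(MB.next_functions)         # ((<AccumulateGrad ...>, 0), (None, 0)) in this example
print(AG.variable is x)          # True: this node accumulates into x.grad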

Final note: most of the Function subclasses are created by automatically generated C++ code, so it is quite tricky to find them. The file used to generate them lives in the PyTorch sources; in that file, the backward for sum is just passing the same gradient value to all the inputs.

I hope this helps 🙂


This explanation solves a lot of my problems, thanks. But I am still a bit confused. Let's come back to the toy example. I understand now that we can trace back from the next_functions graph structure to x. But as we know, in this example x.grad will be a 4x4 matrix of 2s. How can we retrieve this value 2 from the backward pass? For example, take the MulBackward0 node, which corresponds to y = x * 2. How does it know what is being multiplied by what? The Tensor x can be reached through AccumulateGrad.variable, but where is the 2 stored? MB.next_functions[1] is just (None, 0), so there is no Function subclass to follow there anymore.
So my actual final question is: how can we manually use the z.grad_fn instance to calculate the grads for all the inputs? I just want to see the actual data flow. It is not only the backward functions being called; the inputs to those backward functions must be stored somewhere too, I think, but I don't know where to find them.
Another question: can we directly call grad_fn() to compute something?
I tried to use grad_fn(), and I clearly see that if I don't pass the correct shape, it throws an error. So the Function subclass instance does store some shape information, right? But I don't see how to access it.

I think you should open a new topic with more details on what you want to do.

The backward Nodes are just called one after the other, piping the output of one as the input to the next.
These Nodes also contain any data they need from the forward pass to perform their computation.
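As a rough sketch of that piping, reusing the toy example above (this assumes the grad_fn nodes are directly callable, as you observed; whether a node returns a single Tensor or a tuple can vary, and newer PyTorch versions also expose many nodes' saved values through _saved_* attributes):

import torch

x = torch.randn(4, 4, requires_grad=True)
y = 2 * x
z = y.sum()

def first(out):
    # a node may return a single Tensor or a tuple of Tensors depending on the op/version
    return out[0] if isinstance(out, tuple) else out

sum_node = z.grad_fn                          # SumBackward0
mul_node = sum_node.next_functions[0][0]      # MulBackward0

grad_y = first(sum_node(torch.ones_like(z)))  # dz/dy: a 4x4 tensor of ones
grad_x = first(mul_node(grad_y))              # dz/dx: a 4x4 tensor filled with 2

print(grad_x)  # the constant 2 was saved inside the mul node during the forward pass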