About the implementation of backward?

Hello.
I am curious about the implementation of backward in PyTorch.

As far as I know, every forward function object that is created has a corresponding backward function object. What confuses me is the order in which those backward function objects are executed.

Starting from loss.backward(), we need to traverse the computation graph in an order such that a backward function is not executed until all of its inputs (the gradients flowing into it) have been generated. The problem is how to determine this order efficiently. I searched online, but perhaps because I used the wrong keywords, I could not find a satisfying answer.
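The constraint above ("a backward function can't run until all of its inputs are ready") amounts to a topological sort of the backward graph. Below is a minimal sketch using dependency counting (Kahn's algorithm). It assumes a hypothetical graph given as a dict that maps each backward node to the nodes it passes gradients to; the node names are made up for illustration, and this is not PyTorch's code.

```python
from collections import deque


def backward_order(graph, root):
    """Return an order in which each node runs only after every node that
    sends it a gradient has run (Kahn-style dependency counting).

    `graph` is a hypothetical adjacency dict: node -> list of nodes it
    passes gradients to (toward the forward-pass inputs).
    """
    # Pass 1: count how many gradient contributions each reachable node expects.
    deps = {root: 0}
    stack, seen = [root], {root}
    while stack:
        node = stack.pop()
        for nxt in graph.get(node, []):
            deps[nxt] = deps.get(nxt, 0) + 1
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)

    # Pass 2: a node becomes ready once its counter reaches zero.
    ready = deque([root])
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for nxt in graph.get(node, []):
            deps[nxt] -= 1
            if deps[nxt] == 0:
                ready.append(nxt)
    return order


# Toy graph for loss = sum(a*b + a*c): "a" receives gradients from both products.
graph = {
    "sum": ["add"],
    "add": ["mul1", "mul2"],
    "mul1": ["a", "b"],
    "mul2": ["a", "c"],
}
print(backward_order(graph, "sum"))
# → ['sum', 'add', 'mul1', 'mul2', 'b', 'a', 'c']
```

Note that "a" is only emitted after both mul1 and mul2, because its counter starts at 2.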

I think either of the following two methods could solve the problem, but I want to know how PyTorch implements it:

  1. Stack-based solution. Suppose variable a depends on b (for example, a = f(b)). Then b cannot run its backward until every variable that depends on b, a included, has finished its backward. This can naturally be solved recursively. However, when the computation graph is huge and complex, won't the recursion cause a stack overflow?

  2. Counter-based solution. I think we can derive the order of the backward pass from the order of the forward pass, since the two procedures are symmetric. One could assign a counter to every forward function object according to its order of execution. When loss.backward() is called, the backward functions are then executed in the reverse order of their counters. I think this solution may be more memory efficient… However, it assumes that the computation flows sequentially, which may prevent unrelated forward and backward functions from executing in parallel…
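On the stack-overflow concern in option 1: the recursion can be replaced by an explicit stack, so the graph depth is limited by heap memory rather than the call stack (a naive recursive DFS in CPython would exceed the default recursion limit of 1000 frames on a deep graph). A minimal sketch, assuming the same kind of hypothetical adjacency dict (node -> nodes it passes gradients to). It computes a reverse DFS post-order, which is a valid backward order for a DAG, but it is not claimed to be what PyTorch does.

```python
def topo_order_iterative(graph, root):
    """Reverse DFS post-order from `root`, using an explicit stack.

    Every node appears before all nodes it passes gradients to, so the
    result is a valid backward execution order for a DAG.
    """
    visited = {root}
    postorder = []
    # Each entry holds a node and an iterator over its successors.
    stack = [(root, iter(graph.get(root, [])))]
    while stack:
        node, children = stack[-1]
        for child in children:
            if child not in visited:
                visited.add(child)
                stack.append((child, iter(graph.get(child, []))))
                break  # descend first; this iterator resumes later
        else:
            # All successors are finished, so this node is finished too.
            stack.pop()
            postorder.append(node)
    postorder.reverse()
    return postorder


# A chain 100_000 ops deep: far deeper than CPython's default recursion limit.
chain = {i: [i + 1] for i in range(100_000)}
order = topo_order_iterative(chain, 0)
print(order[0], order[-1], len(order))
# → 0 100000 100001
```

The explicit stack stores one (node, iterator) pair per level of depth, so memory still grows with graph depth, but it grows on the heap instead of the interpreter's call stack.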

Now, can anyone tell me something about PyTorch’s solution? Thank you!


PyTorch’s current traversal order is exactly as in option 2, except that it prioritizes the gradient-accumulation steps over all other steps.
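One way to picture this answer is to use option 2's counter as a priority. Among the backward steps whose inputs are all ready, run the one created latest in the forward pass, except that gradient-accumulation steps always jump the queue. The toy sketch below works under those assumptions. The Node class, its seq and is_accumulate fields, and the dependency counting are hypothetical illustrations, not PyTorch's engine code.

```python
import heapq
import itertools
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can be dict keys
class Node:
    name: str
    seq: int                      # forward-pass counter (hypothetical field)
    is_accumulate: bool = False   # True for steps that accumulate into a leaf
    next_nodes: list = field(default_factory=list)


def priority(node):
    # Smaller tuples pop first: accumulation steps before everything else,
    # then nodes created later in the forward pass before earlier ones.
    return (0 if node.is_accumulate else 1, -node.seq)


def run_backward(root):
    # Count how many gradient contributions each reachable node expects.
    deps = {root: 0}
    stack, seen = [root], {root}
    while stack:
        node = stack.pop()
        for nxt in node.next_nodes:
            deps[nxt] = deps.get(nxt, 0) + 1
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)

    # Ready queue: only nodes whose inputs are all available, ordered by priority.
    tie = itertools.count()  # tie-breaker so Node objects are never compared
    ready = [(priority(root), next(tie), root)]
    order = []
    while ready:
        _, _, node = heapq.heappop(ready)
        order.append(node.name)
        for nxt in node.next_nodes:
            deps[nxt] -= 1
            if deps[nxt] == 0:
                heapq.heappush(ready, (priority(nxt), next(tie), nxt))
    return order


# Forward pass of loss = sum(a*b + a*c); seq numbers follow creation order.
acc_a = Node("acc_a", 0, is_accumulate=True)
acc_b = Node("acc_b", 1, is_accumulate=True)
acc_c = Node("acc_c", 2, is_accumulate=True)
mul1 = Node("mul1", 3, next_nodes=[acc_a, acc_b])
mul2 = Node("mul2", 4, next_nodes=[acc_a, acc_c])
add = Node("add", 5, next_nodes=[mul1, mul2])
total = Node("sum", 6, next_nodes=[add])

print(run_backward(total))
# → ['sum', 'add', 'mul2', 'acc_c', 'mul1', 'acc_b', 'acc_a']
```

Here acc_c runs before mul1 even though mul1 has the larger counter, because accumulation steps are prioritized. Ordering by the counter alone would pop mul1 first.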

Great question and thoughts :smiley:

Thanks for solving this puzzle. :smiley: