Hello.
I am curious about how backward is implemented in PyTorch.
As far as I know, every forward function object that gets created has a corresponding backward function object. What I am confused about is the order in which those backward function objects are executed.
Starting from the call loss.backward(), we need to traverse the computation graph in an order such that a backward function is not executed until all of its inputs have been generated. The problem is how to determine that order efficiently. I searched online, but, perhaps because I used poor keywords, I could not find a satisfying answer.
I think the following two methods could solve the problem, but I would like to know how PyTorch implements it:

Stack-based solution. Suppose variable a relies on b. Then a can't run backward until all of the variables that b relies on have finished their backward. This can obviously be solved recursively. However, when the computation graph is huge and complex, will the recursion cause a stack overflow?
Counter-based solution. I think we can determine the order of backward from the order of forward, since the two procedures can be symmetric. So one can assign a counter value to every forward function object based on its order of execution. When loss.backward() is called, the backward functions can be run in decreasing counter order. I think this solution may be more memory efficient. However, it assumes that the computation flow is sequential, which may hinder parallel execution of independent forward and backward functions.
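
To make the two ideas concrete, here is a minimal sketch in plain Python. It is not PyTorch code: the Node class, its inputs and seq fields, and both function names are made up for illustration. It assumes each node records the nodes its forward consumed, so that a node's backward may run only after every node that consumed it has run.

```python
from dataclasses import dataclass, field
from itertools import count

_seq = count()  # hypothetical global counter, bumped once per forward call


@dataclass(eq=False)  # eq=False keeps identity-based hashing
class Node:
    name: str
    inputs: list = field(default_factory=list)  # nodes consumed in forward
    seq: int = field(default_factory=lambda: next(_seq))  # forward order


def backward_order_recursive(root):
    """Method 1: recursive depth-first search, then reverse the post-order."""
    visited, post = set(), []

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for inp in node.inputs:
            visit(inp)  # recursion depth grows with the depth of the graph
        post.append(node)

    visit(root)
    return post[::-1]  # consumers come before the nodes they consumed


def backward_order_counter(root):
    """Method 2: collect reachable nodes, then sort by decreasing counter."""
    seen, stack = {root}, [root]
    while stack:  # explicit stack, so no recursion here
        node = stack.pop()
        for inp in node.inputs:
            if inp not in seen:
                seen.add(inp)
                stack.append(inp)
    return sorted(seen, key=lambda n: n.seq, reverse=True)


# A small diamond-shaped graph: loss depends on a and b, both depend on x.
x = Node("x")
a = Node("a", [x])
b = Node("b", [x])
loss = Node("loss", [a, b])

print([n.name for n in backward_order_recursive(loss)])
print([n.name for n in backward_order_counter(loss)])
```

On a long chain (say, a few thousand nodes linked one after another) I would expect the recursive version to hit Python's default recursion limit, which is the stack overflow concern above. The counter version avoids recursion, but it fixes one total order even for branches that are independent of each other.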
Now, can anyone tell me something about PyTorch's solution? Thank you!