I am curious about the implementation of backward in PyTorch.
As far as I know, every forward function object that is created has a corresponding backward function object. What confuses me is the order in which those backward function objects are executed.
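To make the forward/backward pairing concrete, here is a toy sketch in plain Python. The classes Value and BackwardNode are hypothetical and are not PyTorch's; in PyTorch the analogous objects are reachable through a tensor's grad_fn attribute and grad_fn.next_functions. PyTorch also gives leaf tensors that require grad an AccumulateGrad node, which this toy omits.

```python
# Toy model (hypothetical classes, not PyTorch's): every forward operation
# creates a matching backward node that remembers the backward nodes of its inputs.
class BackwardNode:
    def __init__(self, name, inputs):
        self.name = name        # e.g. "MulBackward"
        self.inputs = inputs    # backward nodes of the forward inputs (edges toward the leaves)


class Value:
    def __init__(self, data, grad_fn=None):
        self.data = data
        self.grad_fn = grad_fn  # None for leaf values in this toy

    def _node(self):
        # Leaves contribute no backward node.
        return [self.grad_fn] if self.grad_fn is not None else []

    def __mul__(self, other):
        node = BackwardNode("MulBackward", self._node() + other._node())
        return Value(self.data * other.data, node)

    def __add__(self, other):
        node = BackwardNode("AddBackward", self._node() + other._node())
        return Value(self.data + other.data, node)


x = Value(3.0)
y = Value(4.0)
z = x * y + x
print(z.grad_fn.name)                      # AddBackward
print([n.name for n in z.grad_fn.inputs])  # ['MulBackward']
```

The forward pass therefore leaves behind a graph of backward nodes rooted at z.grad_fn; the open question is in which order to walk it.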
Starting from the call loss.backward(), we need to traverse the computation graph in an order such that a backward function cannot be executed until all of its inputs have been generated. The problem is how to determine that order efficiently. I searched online, but, maybe because I used the wrong keywords, I couldn't find a satisfying answer.
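One standard way to satisfy "don't run a node until all of its inputs exist" is dependency counting (Kahn's topological sort): first count how many nodes will send a gradient to each node, then run a node only once its count has dropped to zero. The sketch below applies this to a hypothetical dict-based graph; it illustrates the technique only and is not a claim about PyTorch's code.

```python
from collections import deque


def backward_order(graph, root):
    """Return an order in which each node runs only after every node that
    sends it a gradient has run. graph[n] lists the nodes n sends gradients to."""
    # Pass 1: for every node reachable from root, count incoming gradient edges.
    deps = {root: 0}
    stack, seen = [root], {root}
    while stack:
        node = stack.pop()
        for nxt in graph.get(node, []):
            deps[nxt] = deps.get(nxt, 0) + 1
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)

    # Pass 2: a node becomes ready when all of its producers have run.
    ready, order = deque([root]), []
    while ready:
        node = ready.popleft()
        order.append(node)
        for nxt in graph.get(node, []):
            deps[nxt] -= 1
            if deps[nxt] == 0:
                ready.append(nxt)
    return order


# Hypothetical graph for z = a * b + a, written from the loss toward the leaves.
graph = {"Add": ["Mul", "a"], "Mul": ["a", "b"]}
print(backward_order(graph, "Add"))  # ['Add', 'Mul', 'a', 'b']
```

Note that "a" is only released after both Add and Mul have passed their gradients to it, even though Add reaches it directly.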
I can think of the following two methods that would solve the problem, but I want to know how PyTorch actually implements it:
Stack-based solution. Suppose variable a can't run its backward until every variable b that relies on a has finished its backward. This can obviously be solved recursively. However, when the computation graph is huge and complex, will the recursion cause a stack overflow?
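The recursive idea above can be sketched as a depth-first search whose reversed post-order is a valid backward order. In CPython each nested call uses a frame, and deep graphs hit the interpreter's recursion limit (sys.getrecursionlimit(), 1000 by default) and raise RecursionError; keeping an explicit stack of (node, iterator) pairs avoids that. The graph format and function names are hypothetical, and this is a generic technique, not PyTorch's engine.

```python
def topo_recursive(graph, root):
    """Reverse DFS post-order: each node comes before the nodes it sends gradients to."""
    order, visited = [], set()

    def visit(node):
        if node in visited:
            return
        visited.add(node)
        for nxt in graph.get(node, []):
            visit(nxt)
        order.append(node)  # appended only after everything it feeds is done

    visit(root)
    return order[::-1]


def topo_iterative(graph, root):
    """Same order as topo_recursive, but with an explicit stack instead of recursion."""
    order, visited = [], {root}
    stack = [(root, iter(graph.get(root, [])))]
    while stack:
        node, children = stack[-1]
        for nxt in children:
            if nxt not in visited:
                visited.add(nxt)
                stack.append((nxt, iter(graph.get(nxt, []))))
                break  # descend into nxt first; resume this iterator later
        else:
            # All children finished: emit the node in post-order.
            stack.pop()
            order.append(node)
    return order[::-1]


graph = {"Add": ["Mul", "a"], "Mul": ["a", "b"]}  # z = a * b + a
print(topo_recursive(graph, "Add"))  # ['Add', 'Mul', 'b', 'a']
print(topo_iterative(graph, "Add"))  # ['Add', 'Mul', 'b', 'a']

# A long chain, 0 -> 1 -> ... -> 100000, as a stand-in for a very deep graph.
chain = {i: [i + 1] for i in range(100_000)}
try:
    topo_recursive(chain, 0)
except RecursionError:
    print("recursive version hit the recursion limit")
print(len(topo_iterative(chain, 0)))  # 100001
```

The explicit stack moves the bookkeeping from the call stack to the heap, so depth is bounded by memory rather than by the recursion limit.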
Counter-based solution. I think we can determine the order of backward from the order of forward, since the two procedures are symmetric. One can assign a counter to every forward function object based on its execution order. When loss.backward() is called, the backward functions are executed in the reverse order of their counters. I think this solution may be more memory efficient. However, it assumes that the computation flow is sequential, which may prevent unrelated forward and backward functions from executing in parallel.
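The counter idea can be sketched by stamping each node with a global sequence number at forward time and then sorting the reachable nodes by that number in descending order. It works because a node's inputs always exist before the node itself, so they carry smaller numbers. The Node class, the _counter global and backward_by_counter are hypothetical names for illustration only.

```python
import itertools

_counter = itertools.count()  # hypothetical global forward-execution counter


class Node:
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)  # nodes this node sends gradients to
        self.seq = next(_counter)   # stamped in forward execution order


def backward_by_counter(root):
    # Collect every node reachable from root (keyed by identity).
    seen, stack = {id(root): root}, [root]
    while stack:
        for nxt in stack.pop().inputs:
            if id(nxt) not in seen:
                seen[id(nxt)] = nxt
                stack.append(nxt)
    # Inputs were created before their consumers, so consumers have larger seq;
    # descending seq therefore runs every node after all nodes that feed it.
    return sorted(seen.values(), key=lambda n: n.seq, reverse=True)


# Forward pass for z = a * b + a.
a = Node("a")
b = Node("b")
mul = Node("Mul", [a, b])
add = Node("Add", [mul, a])
print([n.name for n in backward_by_counter(add)])  # ['Add', 'Mul', 'b', 'a']
```

The concern raised above is visible here: the result is a single total order, so two independent branches are serialized even though neither waits on the other.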
Now, can anyone tell me about PyTorch's solution? Thank you!