I’m learning about autograd. I know that for y = a * b, y.backward() computes the gradients of a and b, and that it relies on y.grad_fn = MulBackward. Based on this MulBackward, PyTorch knows that dy/da = b and dy/db = a, and we can check the gradient values via a.grad and b.grad.
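A minimal sketch of this setup (the values 2.0 and 3.0 are arbitrary, chosen for illustration):

```python
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

y = a * b
print(y.grad_fn)  # a MulBackward0 object

y.backward()
print(a.grad)  # dy/da = b = 3.0
print(b.grad)  # dy/db = a = 2.0
```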
My question is: how exactly do the different grad_fns (e.g., AddBackward, MulBackward, …) calculate the gradients? Thanks.
As a quick illustrative example, let’s start with two abstract functions, c = f(a, b), and e = g(c, d).
If we know de/dc and de/dd, as well as dc/da and dc/db, we can compute de/da and de/db, because of the classic chain rule:
de/da = (de/dc)(dc/da) + (de/dd)(dd/da) = (de/dc)(dc/da)
de/db = (de/dc)(dc/db) + (de/dd)(dd/db) = (de/dc)(dc/db)
Note that de/dc will have an explicit concrete value without needing to know dc/da or dc/db as we traverse backward through the graph. This de/dc is the grad_output as we view the computation graph from the perspective of f(a, b). We multiply this value with dc/da to get de/da, and with dc/db to get de/db. In the concrete case where f(a, b) = a * b, this means multiplying by b to get de/da and multiplying by a to get de/db. Note that this pattern is general. It does not matter if the graph extends past e = g(c, d) with more functions, or if there are functions before c = f(a, b), because the pattern can be extended until all gradients are computed.
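To make the pattern concrete, here is one possible instance where both f and g are multiplication (the values are arbitrary), comparing the manual chain rule to what autograd produces:

```python
import torch

# c = f(a, b) = a * b, then e = g(c, d) = c * d
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
d = torch.tensor(4.0, requires_grad=True)

c = a * b   # c = 6
e = c * d   # e = 24
e.backward()

# Manual chain rule:
# de/dc = d = 4  -- this is the grad_output seen by f(a, b)
# de/da = (de/dc) * (dc/da) = d * b = 12
# de/db = (de/dc) * (dc/db) = d * a = 8
print(a.grad)  # tensor(12.)
print(b.grad)  # tensor(8.)
print(d.grad)  # de/dd = c = 6, so tensor(6.)
```

Note that de/dc = 4 is fully determined by g and d alone; nothing about a or b is needed to produce it.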
You might want to take a look at some references on reverse-mode automatic differentiation, which is the typical way this is done in most frameworks (Reverse Accumulation) in this article: Automatic differentiation - Wikipedia. CS231n also has some great hands-on examples with a “real” computation (matmul) at the end: CS231n Convolutional Neural Networks for Visual Recognition.
Hi, thank you for your help. What do you mean by
Note that de/dc will have an explicit concrete value without needing to know dc/da or dc/db as we traverse backward through the graph.
Do you mean: if g() here is multiplication, then PyTorch knows that the derivative of multiplication must be de/dc = d? So, as we go backward through the computation graph, we can compute de/dc without knowing anything about dc/da or dc/db, since e = g(c, d) comes after a and b?
Yes, that is the critical part. In order for autograd to work, every supported op must have a backward function (or more than one, depending on the number of inputs) defined for this purpose. The function doesn’t necessarily store b; it just knows that the derivative is the other input. If we rewrite c = f(a, b) for a moment as output = f(input, weight), there are definitions for “how to compute the input gradient given the output gradient and the weight” and “how to compute the weight gradient given the output gradient and the input.” While the grad_fn itself doesn’t store the values, we need to pass them to grad_fn to compute the gradients.
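One way to see this mechanism in user code is torch.autograd.Function, which lets you supply exactly such a pair of forward/backward definitions. This is a sketch of how a multiply op’s backward could look (the class name MyMul is illustrative; it is not PyTorch’s internal MulBackward implementation):

```python
import torch

class MyMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight):
        # Save the inputs: backward needs each one
        # to compute the other's gradient.
        ctx.save_for_backward(input, weight)
        return input * weight

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # d(output)/d(input) = weight and d(output)/d(weight) = input,
        # each scaled by the incoming grad_output (chain rule).
        grad_input = grad_output * weight
        grad_weight = grad_output * input
        return grad_input, grad_weight

x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(5.0, requires_grad=True)
out = MyMul.apply(x, w)
out.backward()
print(x.grad)  # 5.0, i.e. the other input
print(w.grad)  # 2.0
```

The saved tensors live on the graph node (ctx here), and backward receives grad_output from the op downstream, which is exactly the de/dc flow described above.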