How does a PyTorch Module do backprop?

As I have posted here, I’m kind of confused, as well as amazed, by how a PyTorch Module does backprop. It doesn’t seem like we have linked the LinearFunction used in the forward function anywhere in this example. Could anyone kindly explain to me how Module figures out the actual backward function to call? Thanks!

Whenever you do an operation with Variables as input (at least one of them with requires_grad=True), PyTorch builds up a graph of the computations involved and stores any intermediate values needed for backpropagation.

For example:

>>> import torch
>>> from torch.autograd import Variable
>>> a = Variable(torch.randn(1), requires_grad=True)
>>> b = a*a
>>> print(b.grad_fn)
<MulBackward1 object at 0x7f0f6a6d2a90>

As you can see, the Variable b has a .grad_fn attribute, in which PyTorch has stored the details of the calculation used to produce the value stored in b.
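To see that recorded node actually being used, here is a minimal sketch, assuming PyTorch >= 0.4 (where a plain tensor created with requires_grad=True plays the role of the old Variable). Calling backward() on b walks back from b.grad_fn and accumulates d(a*a)/da = 2a into a.grad:

```python
import torch

# Leaf tensor that autograd tracks (stands in for Variable(..., requires_grad=True))
a = torch.randn(1, requires_grad=True)
b = a * a  # records a multiplication node in b.grad_fn

# backward() starts at b.grad_fn and propagates gradients to the leaves
b.backward()

# The gradient of a*a with respect to a is 2a
print(a.grad, 2 * a.detach())
```

Note that a itself has no grad_fn: it is a leaf created by the user, not the result of an operation.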

I see. So I guess it’s the Variable object that keeps track of the backward function. Still, I’m not sure exactly how the Module and Function classes achieve this, since there is no explicit registration of the backward function. So far I’ve found that FunctionMeta might be where all this happens, but I’m not quite sure how.
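A minimal sketch of the mechanism, assuming the static-method Function API of PyTorch >= 0.4 and a simplified, bias-free version of LinearFunction; MyLinear is a hypothetical module name. The link to backward is made when LinearFunction.apply() runs: the output gets a grad_fn node whose backward calls the static backward() defined on the class. In recent releases the FunctionMeta metaclass appears to generate that companion node class automatically for every Function subclass, which would explain why no explicit registration is needed. The Module only calls apply() inside forward().

```python
import torch
from torch import nn
from torch.autograd import Function


class LinearFunction(Function):
    # Simplified (no bias). apply() runs forward and attaches a grad_fn
    # node to the output that will later call this class's backward.
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)  # values backward will need
        return x.mm(weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output.mm(weight)          # dL/dx
        grad_weight = grad_output.t().mm(x)      # dL/dweight
        return grad_x, grad_weight


class MyLinear(nn.Module):
    # Hypothetical module: it never mentions backward at all.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        return LinearFunction.apply(x, self.weight)


torch.manual_seed(0)
layer = MyLinear(3, 2)
x = torch.randn(4, 3, requires_grad=True)
out = layer(x)
print(type(out.grad_fn).__name__)  # name of the node autograd generated
out.sum().backward()

# Cross-check against PyTorch's built-in autograd for the same computation
x_ref = x.detach().clone().requires_grad_(True)
w_ref = layer.weight.detach().clone().requires_grad_(True)
(x_ref @ w_ref.t()).sum().backward()
print(torch.allclose(x.grad, x_ref.grad),
      torch.allclose(layer.weight.grad, w_ref.grad))
```

Since the custom backward produces the same gradients as the built-in operations, the Module never has to know which backward function to call; the tensor returned by apply() carries that information.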

TBH, I don’t have much idea how it really works. I just know that if I operate on Variables and avoid in-place operations, then backprop works.
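An in-place operation is only a problem when it overwrites a value that autograd saved for the backward pass. A small sketch, assuming PyTorch >= 0.4 and relying on the fact that the sigmoid gradient s(1 - s) is computed from sigmoid's own output, so that output is saved:

```python
import torch

a = torch.randn(3, requires_grad=True)

# sigmoid saves its output for backward
y = a.sigmoid()
y.mul_(2)  # in-place: overwrites the saved output and bumps its version

try:
    y.sum().backward()
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print("backward failed:", err)

# The out-of-place version leaves the saved output intact
z = a.sigmoid() * 2
z.sum().backward()
print(a.grad)
```

Autograd detects the modification through a version counter on the saved tensor and raises an error instead of silently computing a wrong gradient.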

Some pointers…

  • The Module class has nothing to do with building the computation graph.
  • If I do a*b, then Python calls a.__mul__(b), which must be defined somewhere on the Variable class, though I haven’t found where.
  • Operations like torch.matmul also work on Variables, so they might provide clues.
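The first pointers can be checked directly. A short sketch, assuming PyTorch >= 0.4, where Tensor has absorbed Variable: the `*` operator goes through `__mul__`, and torch.matmul attaches a grad_fn just the same, with no Module involved. Where `__mul__` is implemented in the source varies between releases, so the sketch only checks that it exists on torch.Tensor.

```python
import torch

a = torch.randn(2, 2, requires_grad=True)
b = torch.randn(2, 2, requires_grad=True)

# a * b is syntactic sugar for a.__mul__(b); both record the same kind of node
c1 = a * b
c2 = a.__mul__(b)
print(type(c1.grad_fn).__name__, type(c2.grad_fn).__name__)

# torch.matmul also records a node, without any nn.Module
d = torch.matmul(a, b)
print(type(d.grad_fn).__name__)

# __mul__ is available on the tensor class itself
print(hasattr(torch.Tensor, "__mul__"))
```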

Be aware that Variables and Tensors are being merged in the current master branch, so if you are digging around in the source code, you might find it easier to stick to the v0.3.1 branch.
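For anyone reading this on a release after the merge (assumed here to be PyTorch >= 0.4), a quick check shows that the old Variable wrapper now just returns an ordinary Tensor, so the examples in this thread can be written without it:

```python
import torch
from torch.autograd import Variable

# Old-style construction still runs, but yields a plain Tensor
v = Variable(torch.randn(3), requires_grad=True)
print(type(v), v.requires_grad)

# The modern equivalent
t = torch.randn(3, requires_grad=True)
print(type(t), t.requires_grad)
```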


Also, there are some functions stored in grad_fn, and I think these functions take part in the backward pass.

from torch.autograd import Variable
print(b.grad_fn.next_functions)  # b = a*a from the example above

The output of the code above is:

((<AccumulateGrad object at 0x7f4705d6a5f8>, 0), (<AccumulateGrad object at 0x7f4705d6a5f8>, 0))
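Each entry in next_functions is a (node, input_index) pair pointing one step further back in the graph. AccumulateGrad is the node for a leaf: it adds the incoming gradient into that leaf's .grad. Both entries in the output above show the same address because both inputs of a*a are a. A minimal sketch of walking the whole graph, using a hypothetical helper walk_graph and assuming a recent PyTorch (where the multiplication node is named MulBackward0 rather than MulBackward1):

```python
import torch


def walk_graph(fn, depth=0, seen=None):
    """Print the backward graph reachable from a grad_fn node.

    Returns node names in visit order; a node reached through
    several edges is visited only once.
    """
    if seen is None:
        seen = set()
    if fn is None or fn in seen:  # None marks an input that needs no grad
        return []
    seen.add(fn)
    name = type(fn).__name__
    print("  " * depth + name)
    names = [name]
    # next_functions holds (node, input_index) pairs, one per input
    for next_fn, _ in fn.next_functions:
        names.extend(walk_graph(next_fn, depth + 1, seen))
    return names


a = torch.randn(1, requires_grad=True)
b = a * a
names = walk_graph(b.grad_fn)
```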