In PyTorch, why is grad_fn stored in the next node, even though it uses the previous node's value to calculate the gradient and updates the grad of the previous node?

In PyTorch, why is grad_fn stored in the next node and not in the previous node? I think it should be stored in the previous node, since it uses the previous node's value to calculate the gradient and updates the grad of the previous node.

For example, suppose we have y = x**3.

Then dy/dx = 3*x**2.

The function that computes this is stored in the grad_fn of y. But it requires the value of the previous node, x, and its result is used to update the .grad attribute of that same previous node, x.
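A minimal sketch of this example, assuming a recent PyTorch release (PowBackward0 is the node name current versions report for a power with a scalar exponent). It shows that the backward node hangs off y while the gradient lands in x.grad:

```python
import torch

# A leaf tensor; requires_grad=True makes autograd record operations on it.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 3

# The backward function is attached to the result y, not to the input x.
print(type(y.grad_fn).__name__)  # PowBackward0
print(x.grad_fn)                 # None: x is a leaf created by the user

# y is not a scalar, so pass an explicit upstream gradient of ones.
y.backward(torch.ones_like(y))

# dy/dx = 3 * x**2, accumulated into the .grad of the input x.
print(torch.allclose(x.grad, 3 * x.detach() ** 2))  # True
```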

One possible argument I see is that x may be used in multiple separate paths. But in that case I don't see why there couldn't just be multiple grad_fns stored in x, with the gradients from each grad_fn summed when updating.
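For reference, autograd already sums contributions when a tensor feeds several paths; the difference is only where the backward functions live (on the results a and b below, not on x). A small sketch with illustrative tensor names:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# x feeds two separate paths; each result gets its own grad_fn.
a = x ** 2   # path 1: da/dx = 2x
b = 3 * x    # path 2: db/dx = 3
out = a + b

out.backward()

# Contributions from both paths are summed into x.grad:
# d(out)/dx = 2x + 3 = 7 at x = 2.
print(x.grad)
```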

Hi Liyuan!

That the backward computation needs the value of the previous node is not true in general. (It's probably true in your case.)

Consider:

y = a * x (where a.requires_grad = False)

Here the tensor a is needed to backpropagate through y (that is, to compute x.grad), but neither x nor y is needed.
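A quick check of this case (values chosen arbitrarily for illustration): with a not requiring grad, the gradient that reaches x is just a, whatever the values of x and y.

```python
import torch

a = torch.tensor([2.0, -1.0, 0.5])                     # requires_grad=False by default
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = a * x

print(type(y.grad_fn).__name__)  # MulBackward0

y.backward(torch.ones_like(y))

# dy/dx = a: the result depends only on a, not on the values of x or y.
print(torch.equal(x.grad, a))    # True
```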

y = x**2

Here x is needed (more or less as in your example).
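PyTorch lets you peek at what a built-in backward node saved through its underscore-prefixed `_saved_*` attributes (described in the autograd mechanics notes; treat them as an introspection aid rather than a stable API). A sketch, assuming those attributes are available in your version:

```python
import torch

x = torch.randn(5, requires_grad=True)
y = x ** 2

# The node attached to y keeps a reference to the input x...
print(type(y.grad_fn).__name__)        # PowBackward0
print(x.equal(y.grad_fn._saved_self))  # True: the saved tensor is x

# ...because the backward formula 2 * x * grad needs it.
# (Saved tensors are freed by backward, so inspect them before this call.)
y.backward(torch.ones_like(y))
print(torch.allclose(x.grad, 2 * x.detach()))  # True
```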

y = exp(x)

Here y is needed to backpropagate through y (since d exp(x)/dx = exp(x) = y), and x is not needed.
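To make the "save the output, not the input" choice explicit, here is a minimal custom torch.autograd.Function. The class name Exp is ours; the pattern follows the example in PyTorch's "Extending torch.autograd" notes. The last lines assume the `_saved_result` introspection attribute is available, as in recent releases:

```python
import torch

class Exp(torch.autograd.Function):
    """Custom exp that saves only its output for the backward pass."""

    @staticmethod
    def forward(ctx, x):
        result = x.exp()
        ctx.save_for_backward(result)  # keep y, not x
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result    # d exp(x)/dx = exp(x) = y

x = torch.randn(4, requires_grad=True)
y = Exp.apply(x)
y.backward(torch.ones_like(y))
print(torch.allclose(x.grad, y.detach()))  # True

# The built-in op makes the same choice: its node saves the result y.
x2 = torch.randn(4, requires_grad=True)
y2 = torch.exp(x2)
print(y2.equal(y2.grad_fn._saved_result))  # True
```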

So the tensor(s) needed for backpropagation are often not naturally part of the “previous node.”

Let x and y both carry requires_grad = True and z = x @ y.

Backpropagation from z flows back to both x and y. The grad_fn for the z node is grad_fn=<MmBackward0>. PyTorch chooses to store grad_fn once, in node z, rather than twice, in both nodes x and y.
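A sketch of this case with 2-D tensors (shapes chosen arbitrarily): a single MmBackward0 node on z has two upstream edges, one per input, and one backward call through it fills both x.grad and y.grad.

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
y = torch.randn(3, 4, requires_grad=True)
z = x @ y

node = z.grad_fn
print(type(node).__name__)  # MmBackward0

# One node, two edges: one toward x, one toward y (leaf accumulators).
print([type(fn).__name__ for fn, _ in node.next_functions])

z.backward(torch.ones_like(z))

# Both gradients come from the single MmBackward0 node.
g = torch.ones(2, 4)
print(torch.allclose(x.grad, g @ y.detach().T))  # True
print(torch.allclose(y.grad, x.detach().T @ g))  # True
```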

Last, I believe (but am not certain) that grad_fn contains the link(s) to the next (upstream) node(s) in the computation graph. Since the computation graph is traversed from "result node" (downstream) to "argument node(s)" (upstream), backpropagation needs grad_fn to be associated with the "result node" in order to be able to find the "argument node(s)."
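One way to check this is to follow grad_fn.next_functions, the attribute autograd nodes expose for their upstream edges. The helper name backward_graph is hypothetical, and this naive recursion revisits shared nodes once per path, which is fine for a small illustration:

```python
import torch

def backward_graph(fn, depth=0, out=None):
    """Collect (depth, node name) pairs by following next_functions upstream."""
    if out is None:
        out = []
    if fn is None:  # an input that does not require grad has no backward node
        return out
    out.append((depth, type(fn).__name__))
    for next_fn, _ in fn.next_functions:
        backward_graph(next_fn, depth + 1, out)
    return out

x = torch.tensor([1.0, 2.0], requires_grad=True)
w = torch.tensor([0.5, -1.0], requires_grad=True)
loss = (torch.exp(x) * w).sum()

# Start at the result node and walk toward the inputs.
graph = backward_graph(loss.grad_fn)
for depth, name in graph:
    print("  " * depth + name)
```

Starting from loss.grad_fn reaches every node down to the leaf accumulators, while the leaves themselves (x.grad_fn and w.grad_fn are None) have no way to reach the result.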

Best.

K. Frank