How to manually do chain rule backprop?

Hello,

Is this approximately what you need? Let's say you have

import torch
from torch.autograd import Variable

x = Variable(torch.randn(4), requires_grad=True)
y = f(x)   # f is your first differentiable function

# use y.data to construct a new Variable, so the two graphs stay separate
y2 = Variable(y.data, requires_grad=True)
z = g(y2)  # g is your second differentiable function

(There is also Variable.detach, but let's leave that aside for now.)

Then you can do (assuming z is a scalar)

z.backward()         # computes dz/dy2 and stores it in y2.grad
y.backward(y2.grad)  # chain rule: dz/dx = dy/dx * y2.grad, stored in x.grad
print(x.grad)
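
If z is not a scalar, .backward needs a gradient argument of the same shape as z; for instance, to backpropagate from z.sum() you would replace the first line with (just a sketch, not from the recipe above):

z.backward(torch.ones(z.size()))  # same as backpropagating from z.sum()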

Note that .backward evaluates the derivative at the values from the last forward computation.
(I hope this is correct, I don’t have access to my pytorch right now.)
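
For completeness, here is a self-contained sketch of the same idea in newer PyTorch versions, where Variable is merged into Tensor; f and g here are made-up example functions, and y.detach() plays the same role as Variable(y.data, ...). The result should match backpropagating through the composed graph in one go:

import torch

# made-up example functions standing in for f and g
def f(x):
    return x * 2 + 1

def g(y):
    return (y ** 2).sum()

x = torch.randn(4, requires_grad=True)

# two-stage backward, as above
y = f(x)
y2 = y.detach().requires_grad_(True)  # separate the graphs
z = g(y2)
z.backward()                          # dz/dy2 ends up in y2.grad
y.backward(y2.grad)                   # dz/dx = dy/dx * dz/dy2 ends up in x.grad

# single-stage backward through the composed graph, for comparison
x_ref = x.detach().clone().requires_grad_(True)
g(f(x_ref)).backward()

print(torch.allclose(x.grad, x_ref.grad))  # True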

Best regards

Thomas
