Track .grad gradient graph

Maybe the title is confusing. Here is a small example:

import torch

w = torch.tensor([3.2], requires_grad=True)
p = torch.tensor([2.0, 1.0, 7.0], requires_grad=True)
g = torch.sum(p ** 2)
e = w * g
e.backward(retain_graph=True)
f = p.grad  # de/dp = 2 * p * w, but stored as a plain tensor

l = e + 0.5 * f.mean()
l.backward()

Here f is a function of w; however, we lose the gradient graph of f (it is only a plain tensor), even though we know f = 2pw analytically. When we call backward() on l, we want f to contribute to the gradient of w (df/dw = 2p). How can I make that happen here?
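
For example, inspecting f shows that nothing was recorded for it:

print(f.grad_fn)        # None: the graph of the backward pass was not kept
print(f.requires_grad)  # False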

Thanks.

Hi,

You want to use e.backward(create_graph=True); that will make sure the backward pass is run in a differentiable manner.
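
For example, with the snippet from the question (the exact grad_fn name may vary across PyTorch versions):

e.backward(create_graph=True)  # instead of retain_graph=True
f = p.grad
print(f.grad_fn)  # now a backward node instead of None, so f can be differentiated further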

Also, creating a .grad that itself requires grad is usually not recommended, because when you then call l.backward(), you will accumulate into the same p.grad, which can easily become confusing. You can do the following to make it simpler:

import torch
from torch import autograd

w = torch.tensor([3.2], requires_grad=True)
p = torch.tensor([2.0, 1.0, 7.0], requires_grad=True)
g = torch.sum(p ** 2)
e = w * g
f = autograd.grad(e, p, create_graph=True)[0]  # f = 2 * p * w, with its graph kept

l = e + 0.5 * f.mean()
l.backward()
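
As a sanity check, the result should match the hand-derived gradients: with f = 2pw we get dl/dw = sum(p**2) + mean(p) and dl/dp = 2*w*p + w/len(p):

with torch.no_grad():
    assert torch.allclose(w.grad, (p ** 2).sum() + p.mean())
    assert torch.allclose(p.grad, 2 * w * p + w / p.numel())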

Hi, albanD! Thanks for your help. However, I am confused about the second derivative of f with respect to w (the rightmost part of the graph). What is ExpandBackward here? What do the 0 and 1 labels after Backward mean? Is MulBackward0 the same for both paths?

Hi,

Yes, these graphs are usually not very easy to read.
They come from the graph created inside the backward pass. The 0/1 suffixes are used to differentiate the different overloads of a function (mul(t1, t2) vs. mul(t1, 2), for example).
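
For example (a small standalone check; the exact suffixes depend on the PyTorch version):

import torch

t = torch.tensor([1.0, 2.0], requires_grad=True)
print((t ** 2).grad_fn)  # PowBackward0: the pow(Tensor, Scalar) overload
print((t ** t).grad_fn)  # PowBackward1: the pow(Tensor, Tensor) overload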

Yep. Could you please explain the rightmost backward flow step by step for this example? For instance, MeanBackward0 corresponds to f.mean(). I'm having trouble finding a one-to-one correspondence for the rest.

Well, the problem is that the one-to-one correspondence is with the ops executed during the backward pass, not with the ops you wrote.

For example, the ExpandBackward comes from the backward formula of sum, which uses expand.
The MulBackward1 and PowBackward on the right side come from the backward pass of the power function, which depends on its input.
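
You can see the expand directly (node names may vary by version):

import torch

p = torch.tensor([2.0, 1.0, 7.0], requires_grad=True)
(gp,) = torch.autograd.grad(p.sum(), p, create_graph=True)
print(gp.grad_fn)  # e.g. <ExpandBackward0 ...>: sum's backward expands grad_output to p's shape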

I see. Some backward nodes are not intuitive, but they make sense given what PyTorch actually records. Thanks for your help.

How did you generate this graph?

You can find this tool here: https://github.com/szagoruyko/pytorchviz
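
Typical usage looks like this (a sketch; make_dot returns a graphviz Digraph, so graphviz must be installed):

from torchviz import make_dot

dot = make_dot(l, params={"w": w, "p": p})  # l, w, p from the example above
dot.render("grad_graph", format="png")      # writes grad_graph.png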


You saved my day; looking into the source code helped a lot. Much appreciated.