```python
import torch

a = torch.tensor(1., requires_grad=True)
b = 2 * a
c = a + b
print(torch.autograd.grad(c, a, retain_graph=True))
```
Output: `(tensor(3.),)`. Expected output: 1, i.e. the partial derivative of c with respect to a.
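The 3 comes from autograd returning the total derivative, which sums every path from a to c (directly, and through b = 2a). Writing out the chain rule for this toy graph shows where the 1 and the 3 each come from:

$$
\frac{dc}{da} = \underbrace{\frac{\partial c}{\partial a}}_{=1} + \underbrace{\frac{\partial c}{\partial b}}_{=1}\,\underbrace{\frac{db}{da}}_{=2} = 1 + 2 = 3
$$

The "expected" value 1 is only the first term, the partial derivative with b held fixed.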
I know we can use `b = 2 * a.detach()`. But suppose `d = 2 * c`. In that case I still want to calculate the total derivative dd/da: I want dd/da = 6 (without `a.detach()`), not the 2 that the detached version gives.
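A small sketch of the two behaviours described above, assuming only the toy graph from the question: detaching a inside b gives the partial dc/da = 1 but also shrinks dd/da to 2, while the undetached graph gives the totals 3 and 6. The variable names with `1`/`2` suffixes are just for side-by-side comparison.

```python
import torch

a = torch.tensor(1., requires_grad=True)

# Version 1: b is built from a detached copy of a, so autograd ignores the b path.
b_detached = 2 * a.detach()
c1 = a + b_detached
d1 = 2 * c1
(dc1_da,) = torch.autograd.grad(c1, a, retain_graph=True)
(dd1_da,) = torch.autograd.grad(d1, a)

# Version 2: no detach, so autograd follows every path from a.
b = 2 * a
c2 = a + b
d2 = 2 * c2
(dc2_da,) = torch.autograd.grad(c2, a, retain_graph=True)
(dd2_da,) = torch.autograd.grad(d2, a)

print(dc1_da.item(), dd1_da.item())  # 1.0 2.0
print(dc2_da.item(), dd2_da.item())  # 3.0 6.0
```

Neither single graph gives both numbers at once, which is the crux of the question.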
Background: I am trying to calculate the partial gradient of the last hidden layer with respect to the first hidden layer in an RNN. After that I still need to call `loss.backward()`, and there I don't want the hidden layer to be detached.
PyTorch builds the graph dynamically, but I think a `stop_gradients` feature would still be useful when we want to change the retained graph.
Regarding your scenario, I think something like this may work:
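```python
import torch

a = torch.tensor(1., requires_grad=True)
b = 2 * a
c = a + b                    # full graph: c depends on a directly and through b
c_detach = a + b.detach()    # extra node where the path through b is cut
d = 2 * c
loss = d.sum()

print(torch.autograd.grad(d, a, retain_graph=True))         # total dd/da = 6
print(torch.autograd.grad(c, a, retain_graph=True))         # total dc/da = 3
print(torch.autograd.grad(c_detach, a, retain_graph=True))  # partial dc/da = 1
print(a.grad)   # None: autograd.grad does not accumulate into .grad
loss.backward()
print(a.grad)   # tensor(6.): backward still uses the undetached graph
```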
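A variation that avoids building a second `c_detach` expression, sketched under the assumption that you can touch the line where a is used directly: route that direct use through a differentiable alias (here a hypothetical `a_direct = a.clone()`) and differentiate with respect to the alias. Because `torch.autograd.grad` accepts non-leaf inputs, the gradient w.r.t. the alias only counts the direct path, while gradients w.r.t. `a` still include every path.

```python
import torch

a = torch.tensor(1., requires_grad=True)

# Hypothetical alias: a_direct carries only the direct use of a in c.
a_direct = a.clone()
b = 2 * a               # the indirect path still uses a itself
c = a_direct + b
d = 2 * c

# Partial dc/da with b held fixed: differentiate w.r.t. the alias.
(dc_da_partial,) = torch.autograd.grad(c, a_direct, retain_graph=True)
# Total dd/da through all paths (clone is differentiable, so nothing is cut).
(dd_da_total,) = torch.autograd.grad(d, a, retain_graph=True)

d.backward()
print(dc_da_partial.item(), dd_da_total.item(), a.grad.item())  # 1.0 6.0 6.0
```

This keeps a single graph, but it has the same limitation as `c_detach`: you need access to the place where the variable is consumed, which a fused module does not expose.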
Hi Shengwei,
Thank you for the answer. But I think creating an extra `c_detach` variable only works for this simple case. For PyTorch's built-in `nn.RNN`, we can't do this to the hidden states unless we build the structure from scratch.
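One way to "build the structure from scratch" with little code is to unroll `nn.RNNCell` in a Python loop so every hidden state is an ordinary tensor in the graph. The sketch below assumes a single-layer tanh RNN and made-up sizes and loss. Differentiating w.r.t. the intermediate `h_first` only counts paths that go forward from it, and `retain_graph=True` keeps the graph for the later `loss.backward()`. Note that `h_last.sum()` yields a vector-Jacobian product (gradient of the sum), not the full Jacobian; `torch.autograd.functional.jacobian` or a loop over output elements would be needed for that.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
seq_len, batch, input_size, hidden_size = 5, 3, 4, 6

cell = nn.RNNCell(input_size, hidden_size)
x = torch.randn(seq_len, batch, input_size)
h = torch.zeros(batch, hidden_size)

# Unroll manually so each hidden state is kept as a graph node.
hiddens = []
for t in range(seq_len):
    h = cell(x[t], h)
    hiddens.append(h)

h_first, h_last = hiddens[0], hiddens[-1]

# Gradient of sum(h_last) w.r.t. the first hidden state; graph is retained.
(dh_last_dh_first,) = torch.autograd.grad(h_last.sum(), h_first, retain_graph=True)
print(dh_last_dh_first.shape)  # torch.Size([3, 6])

# Hypothetical loss; nothing was detached, so parameters still get gradients.
loss = h_last.pow(2).mean()
loss.backward()
```

The cost of this design is speed: a Python loop over `nn.RNNCell` gives up the fused kernel that `nn.RNN` can use.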