Partial derivatives like tf.gradients' stop_gradients parameter

Hi all,
Is there a way to calculate partial derivatives instead of the total derivative, like tf.gradients with its stop_gradients parameter? https://www.tensorflow.org/api_docs/python/tf/gradients
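
From the linked docs, the behavior I mean is roughly this (TF 1.x style, my paraphrase of the doc example, so take the exact values as approximate):

import tensorflow as tf

a = tf.constant(0.)
b = 2 * a
# With stop_gradients, a + b is differentiated as if a and b were constants,
# so as far as I understand the docs this evaluates to [1.0, 1.0].
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])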

import torch

a = torch.tensor(1., requires_grad=True)
b = 2 * a
c = a + b
print(torch.autograd.grad(c, a, retain_graph=True))

Output: (tensor(3.),). Expected output: 1 for the partial derivative, treating b as a constant.
I know we can use b = 2 * (a.detach()). But what if d = 2 * c? In that case I still want the total derivative dd/da: I want dd/da = 6, not the 2 I would get with a.detach().
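
To illustrate, with the detach workaround the total derivative changes too (a quick sketch of what I mean):

import torch

a = torch.tensor(1., requires_grad=True)
b = 2 * a.detach()   # gradient through b is stopped completely
c = a + b
d = 2 * c
print(torch.autograd.grad(d, a))  # (tensor(2.),) -- but here I want dd/da = 6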

Background: I am trying to calculate the partial gradient of the last hidden layer w.r.t. the first hidden layer in an RNN. Then I still need to call loss.backward(), where I don't want to detach the hidden layer.

Thank you all.

I think the reason TensorFlow provides stop_gradients for tf.gradients is mainly that it "provides a way of stopping gradient after the graph has already been constructed".

PyTorch, however, builds the graph dynamically. But I think this stop_gradients feature is still useful when we want to change the retained graph.

Regarding your scenario, I think something like this may work:

import torch

a = torch.tensor(1., requires_grad=True)
b = 2 * a
c = a + b                  # total: dc/da = 3
c_detach = a + b.detach()  # only the direct path: dc_detach/da = 1
d = 2 * c
loss = d.sum()

print(torch.autograd.grad(d, a, retain_graph=True))         # (tensor(6.),)
print(torch.autograd.grad(c, a, retain_graph=True))         # (tensor(3.),)
print(torch.autograd.grad(c_detach, a, retain_graph=True))  # (tensor(1.),)
print(a.grad)                                               # None
loss.backward()
print(a.grad)                                               # tensor(6.)

Hi Shengwei,
Thank you for the answer. But I think the method of creating an extra c_detach variable only works for this simple case. For PyTorch's built-in nn.RNN, we can't do this to the hidden states unless we build the structure from scratch.
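
What I mean by building it from scratch is roughly a manual unroll with nn.RNNCell (the sizes here are just placeholders), where one could detach or hook the hidden state at a chosen step:

import torch
import torch.nn as nn

cell = nn.RNNCell(input_size=4, hidden_size=8)
x = torch.randn(5, 3, 4)   # (seq_len, batch, input_size)
h = torch.zeros(3, 8)      # initial hidden state

hiddens = []
for t in range(x.size(0)):
    h = cell(x[t], h)      # here one could insert h.detach() or register a hook
    hiddens.append(h)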

Yes. But we might be able to add a hook to the hidden layer to do that.
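
For example, with Tensor.register_hook on the intermediate tensor from the toy example above (just a sketch of the idea, not the RNN case):

import torch

a = torch.tensor(1., requires_grad=True)
b = 2 * a
c = a + b
d = 2 * c

# Temporarily zero the gradient flowing back through b,
# which mimics stop_gradients for the partial dc/da.
handle = b.register_hook(lambda grad: torch.zeros_like(grad))
print(torch.autograd.grad(c, a, retain_graph=True))  # (tensor(1.),)
handle.remove()

# With the hook removed, the total derivative is unaffected.
print(torch.autograd.grad(d, a, retain_graph=True))  # (tensor(6.),)
d.backward()
print(a.grad)                                        # tensor(6.)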

Oh yeah, I haven't investigated hooks yet. I will look into it. Thank you.