How do I hook into the backward computational graph?

Suppose I have two intermediate nodes o_1 = f(a, ...) and o_2 = g(a, ...). Both o_1 and o_2 contribute to the final loss. I want to get the gradient flowing to a from o_1 only, and not from o_2. This is the part of the gradient of the loss that flows to a through the node o_1.

To reiterate what I want for more clarity:
When computing gradients, a single node A can have multiple paths to the final gradient target (the loss). The upstream gradients reach A and are accumulated by summation at A to form its final gradient value. I want to hook into this accumulation process and store each incoming upstream gradient, i.e., store the upstream gradient arriving along each of A's outgoing edges separately, instead of only getting their sum.
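
To illustrate the accumulation (a minimal sketch, not my actual model): a hook registered directly on a via Tensor.register_hook only ever fires with the already-summed gradient, so it cannot separate the per-edge contributions:

import torch

a = torch.randn(3, requires_grad=True)
# Fires once per backward pass, with the gradient already accumulated over all edges
a.register_hook(lambda grad: print("hook on a sees only the sum:", grad))

o_1 = 2 * a       # first edge out of a
o_2 = a ** 2      # second edge out of a
(o_1 + o_2).sum().backward()

print(a.grad)     # 2 + 2*a, i.e. both contributions already summed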

The only way I can think of is to detach the gradient of o_2, but I actually need to get the gradients of a lot of these edges in the computational graph, and detaching o_2 will make the downstream edges incorrect.

Using torch.autograd.grad should work, as it lets you compute the gradient of a specific output with respect to a specific variable:

import torch
import torch.nn as nn

f = nn.Linear(1, 1)
g = nn.Linear(1, 1)

a = torch.randn(1, 1, requires_grad=True)

o_1 = f(a)
o_2 = g(a)

# Gradient flowing into a through o_1 only; keep the graph alive for the next call
g_1 = torch.autograd.grad(outputs=o_1, inputs=a, retain_graph=True)
# Gradient flowing into a through o_2 only
g_2 = torch.autograd.grad(outputs=o_2, inputs=a)

# Recompute the forward pass and compare against the accumulated gradient in a.grad
o_1 = f(a)
o_2 = g(a)
(o_1 + o_2).backward()
print(a.grad)
# tensor([[0.6055]])
print(g_1[0] + g_2[0])
# tensor([[0.6055]])
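
As an aside: this works without a grad_outputs argument only because o_1 and o_2 each contain a single element. For genuinely non-scalar outputs you would pass the upstream gradient explicitly, e.g. (sketch):

# Explicit grad_outputs for a non-scalar o_1
g_1 = torch.autograd.grad(outputs=o_1, inputs=a,
                          grad_outputs=torch.ones_like(o_1),
                          retain_graph=True)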

Thanks! I actually thought of a solution that works:

import torch
from icecream import ic


v = torch.tensor([1., 1., 1.], requires_grad=True)

# A dummy view of v: the view creates its own edge in the graph, and retain_grad()
# keeps the upstream gradient that arrives along that edge
v_dummy1 = v.view(*v.shape)
v_dummy1.retain_grad()

out = v_dummy1 + v**2
out.backward(torch.tensor([1., 2., 3.]))
ic(v_dummy1.grad, v.grad)
# Output: v_dummy1.grad: tensor([1., 2., 3.]), v.grad: tensor([3., 6., 9.])

I don’t want to use autograd.grad, as it mixes the backward computation into the forward definition of the model. I also suspect it’s less efficient when called multiple times.
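
A generalization of the same trick (just a sketch I haven't battle-tested; the tap helper and the grads dict are illustrative names of my own): insert an identity view per edge and register a hook on the view, so each edge's upstream gradient is stashed during the normal backward pass and no extra autograd.grad calls are needed:

import torch

grads = {}  # name -> upstream gradient captured on that edge

def tap(t, name):
    # Identity view: creates a separate edge in the autograd graph for t
    t_view = t.view_as(t)
    def _store(grad):
        # The hook receives exactly the gradient flowing back along this edge
        grads[name] = grad.detach().clone()
    t_view.register_hook(_store)
    return t_view

v = torch.tensor([1., 1., 1.], requires_grad=True)

# Use a separately tapped copy of v on each edge whose gradient we want to isolate
out = tap(v, "v_linear") + tap(v, "v_squared") ** 2
out.backward(torch.tensor([1., 2., 3.]))

print(grads["v_linear"])   # gradient through the first edge: tensor([1., 2., 3.])
print(grads["v_squared"])  # gradient through the second edge: tensor([2., 4., 6.])
print(v.grad)              # accumulated total: tensor([3., 6., 9.])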