torch.autograd.grad for op-by-op backward


I’m trying to use torch.autograd.grad to write a backward procedure operator by operator. It works fine when no tensor is consumed by more than one operator, but I don’t know how to handle residual connections like the one below (tensor a). Is there any way to run backward operator by operator, and at the end accumulate the grad of tensor a manually?

import torch

a = torch.tensor([1, 2, 3], dtype=torch.float).requires_grad_()
b = torch.tensor([1, 2, 3], dtype=torch.float).requires_grad_()
c = torch.tensor([2, 3, 4], dtype=torch.float).requires_grad_()
d = c + a
e = d * b
f = e + a
# f = (c + a) * b + a
loss = torch.sum(f)
# loss.backward()
# print(a.grad)
(grad_e, grad_a1) = torch.autograd.grad([loss], [e, a])
print(grad_e, grad_a1)
(grad_d, grad_b) = torch.autograd.grad([e], [d, b], [grad_e])
print(grad_d, grad_b)
(grad_c, grad_a2) = torch.autograd.grad([d], [c, a], [grad_d])
print(grad_c, grad_a2)
grad_a = grad_a1 + grad_a2

Running this raises:

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
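Two notes on possible fixes, as a sketch rather than a definitive answer. Passing retain_graph=True to the first two grad() calls does silence the error, but the accumulation grad_a1 + grad_a2 would then double-count: grad_a1 from grad([loss], [e, a]) already includes the path through e -> d -> a, not just the direct residual edge. One way to get genuinely per-operator gradients is to detach every operator's inputs during the forward pass, so each op owns its own small graph and each grad() call only touches that local graph; the *_in names below are my own:

```python
import torch

a = torch.tensor([1, 2, 3], dtype=torch.float).requires_grad_()
b = torch.tensor([1, 2, 3], dtype=torch.float).requires_grad_()
c = torch.tensor([2, 3, 4], dtype=torch.float).requires_grad_()

# Forward pass, detaching each op's inputs so every operator
# has an isolated one-op graph. a is used twice, so it gets
# two detached copies (a_in1 for the first add, a_in2 for the
# residual add).
a_in1 = a.detach().requires_grad_()
c_in = c.detach().requires_grad_()
d = c_in + a_in1

d_in = d.detach().requires_grad_()
b_in = b.detach().requires_grad_()
e = d_in * b_in

e_in = e.detach().requires_grad_()
a_in2 = a.detach().requires_grad_()
f = e_in + a_in2            # f = (c + a) * b + a

loss = torch.sum(f)

# Backward, one operator at a time. Each call consumes a
# different local graph, so nothing is freed twice.
(grad_f,) = torch.autograd.grad(loss, f)
grad_e, grad_a2 = torch.autograd.grad(f, [e_in, a_in2], grad_f)
grad_d, grad_b = torch.autograd.grad(e, [d_in, b_in], grad_e)
grad_c, grad_a1 = torch.autograd.grad(d, [c_in, a_in1], grad_d)

# Manually accumulate over the two uses of a.
grad_a = grad_a1 + grad_a2
```

This matches what loss.backward() on the undetached graph would put in a.grad, b.grad, and c.grad (dloss/da = b + 1, dloss/db = c + a, dloss/dc = b), which is an easy sanity check to run alongside the manual version.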