Split loss.backward() into two parts

For multitask learning, i.e. a network with multiple heads, it would help lower the peak GPU memory usage if we could backpropagate each individual head's loss into a common node (the shared features) already during the forward phase. To achieve that, we need to split the backward pass into two stages. For example,

import torch

a = torch.rand(4, requires_grad=True)
b = torch.rand(4, requires_grad=True)
c = a + b
d = 2 * c
e = d.sum()

# Stage 1: calculate the grads of e and d (stop at c)
# e.backward()  ???

# Stage 2 (later): calculate the grads of c, a and b
# c.backward()  ???

For the above code, I want to calculate the grads for nodes d and e first, and at some later stage calculate the grads for a, b and c. Is there a way to achieve that behaviour?
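One way to get this behaviour, sketched here under the assumption that cutting the graph with `detach()` is acceptable, is to make a detached copy of `c` that acts as a new leaf. Stage 1 backpropagates from `e` only down to that leaf. Stage 2 later pushes the gradient collected there through the original `c` into `a` and `b` via `Tensor.backward(gradient)`. The name `c_cut` is hypothetical and only used for illustration.

```python
import torch

# Leaf tensors from the question
a = torch.rand(4, requires_grad=True)
b = torch.rand(4, requires_grad=True)

# Part of the graph below the cut point
c = a + b

# Cut the graph at c: c_cut shares c's data but has no history,
# so a backward pass from e stops here
c_cut = c.detach().requires_grad_(True)

# Part of the graph above the cut point
d = 2 * c_cut
d.retain_grad()  # d is not a leaf, so ask autograd to keep d.grad
e = d.sum()
e.retain_grad()  # same for e

# Stage 1: backpropagate from e down to the cut point only.
# This fills e.grad, d.grad and c_cut.grad (the gradient w.r.t. c)
# and frees the graph between c_cut and e.
e.backward()

# ... later ...

# Stage 2: continue from c into a and b, feeding in the gradient
# that stage 1 accumulated at the cut point.
c.backward(c_cut.grad)

print(d.grad)  # tensor([1., 1., 1., 1.])
print(a.grad)  # tensor([2., 2., 2., 2.])
print(b.grad)  # tensor([2., 2., 2., 2.])
```

For several heads the same pattern would apply, assuming the heads share one trunk output: detach that output once, call each head's `loss.backward()` right after computing the head (the gradients accumulate into the detached tensor's `.grad` and each head's graph is freed), then call `backward` on the trunk output with the accumulated gradient. Whether this actually lowers peak memory depends on how large each head's activations are compared with the trunk's.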

Just to clarify, could you state explicitly, perhaps in mathematical notation, which values you want?

Do you mean you want to call backward on d and e, or do you mean computing the gradient of some tensor with respect to d and e?
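To make the two readings concrete, here is a small sketch on the tensors from the question. `torch.autograd.grad` returns the gradient of `e` with respect to intermediate tensors such as `d` and `c` without touching any `.grad` attribute. `e.backward()` instead accumulates gradients into the `.grad` of the leaves `a` and `b`.

```python
import torch

a = torch.rand(4, requires_grad=True)
b = torch.rand(4, requires_grad=True)
c = a + b
d = 2 * c
e = d.sum()

# Reading 1: gradient of e with respect to the intermediate tensors d and c.
# The result is returned; the leaves' .grad stays None.
# retain_graph=True keeps the graph alive for the backward call below.
grad_d, grad_c = torch.autograd.grad(e, (d, c), retain_graph=True)
print(grad_d)  # tensor([1., 1., 1., 1.])
print(grad_c)  # tensor([2., 2., 2., 2.])
print(a.grad)  # None

# Reading 2: call backward, which accumulates into the .grad of the leaves.
e.backward()
print(a.grad)  # tensor([2., 2., 2., 2.])
```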