For multitask learning, i.e. a network with multiple heads, it is useful for lowering peak GPU memory usage if we can backpropagate each individual head's loss into a common node during the forward phase (as each head finishes), rather than waiting for one full backward pass at the end. To achieve that, we need to cut the backward pass into two stages. For example,
import torch

a = torch.rand(4, requires_grad=True)
b = torch.rand(4, requires_grad=True)
c = a + b
d = 2 * c
e = d.sum()

# stage 1: calculate the gradients for e and d only
e.backward()  # ?

# ...

# stage 2: later, calculate the gradients for a, b and c
c.backward()  # ? (c is not a scalar, so this would need a gradient argument)
For the above code, I want to calculate the gradients for nodes d and e first, and at some later stage calculate the gradients for a, b and c. Is there a way to achieve that behaviour?
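A sketch of one approach I was considering (the name c_detached is just illustrative) is to cut the graph manually at the common node, so the heads' backward stops there, and later feed the accumulated gradient back into the trunk:

import torch

a = torch.rand(4, requires_grad=True)
b = torch.rand(4, requires_grad=True)
c = a + b

# Cut the graph at the common node: c_detached shares data with c
# but starts a fresh graph, so backward from the head stops here.
c_detached = c.detach().requires_grad_()

d = 2 * c_detached
e = d.sum()

# Stage 1: backprop the head only as far as c_detached;
# this populates c_detached.grad and frees the head's buffers.
e.backward()

# ... later ...

# Stage 2: push the accumulated gradient from the cut point
# through the trunk to compute a.grad and b.grad.
c.backward(c_detached.grad)

I'm not sure whether this is the recommended way, or whether there is a cleaner built-in mechanism for splitting the backward pass like this.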