Update weights due to gradients coming from one path and not from another


I have a network with two parallel heads. Each head predicts a different entity, so each has its own loss. I want the backbone to be updated only by the gradients coming from one head, not the other. How can I do this?

(Attaching a figure explaining the same)


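You could call `.detach()` on the output of the backbone before passing it to Head2. This cuts the computation graph at that point: Head2's loss still updates Head2's own parameters, but none of its gradient flows back into the backbone.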
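Here is a minimal sketch of that idea. The architecture (`TwoHeadNet` with a small linear backbone), the layer sizes, and the choice of cross-entropy for Head1 and MSE for Head2 are hypothetical stand-ins for your actual model and losses. The helper `check_gradient_flow` (also hypothetical) verifies two things: backpropagating Head2's loss alone leaves the backbone's gradients empty, and the backbone gradient from the summed loss equals the one from Head1's loss alone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadNet(nn.Module):
    """Hypothetical model: a shared backbone feeding two parallel heads."""

    def __init__(self, in_dim=8, feat_dim=16, out1=3, out2=2):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, feat_dim), nn.ReLU())
        self.head1 = nn.Linear(feat_dim, out1)
        self.head2 = nn.Linear(feat_dim, out2)

    def forward(self, x):
        feat = self.backbone(x)
        out1 = self.head1(feat)           # Head1's gradients reach the backbone
        out2 = self.head2(feat.detach())  # detach: Head2's gradients stop here
        return out1, out2


def check_gradient_flow(seed=0):
    torch.manual_seed(seed)
    model = TwoHeadNet()
    x = torch.randn(4, 8)
    y1 = torch.randint(0, 3, (4,))  # class labels for Head1 (assumed task)
    y2 = torch.randn(4, 2)          # regression targets for Head2 (assumed task)

    # 1) Backpropagate Head2's loss alone.
    _, out2 = model(x)
    F.mse_loss(out2, y2).backward()
    backbone_untouched = all(p.grad is None for p in model.backbone.parameters())
    head2_has_grad = all(p.grad is not None for p in model.head2.parameters())

    # 2) Backbone gradient from the summed loss...
    model.zero_grad(set_to_none=True)
    out1, out2 = model(x)
    (F.cross_entropy(out1, y1) + F.mse_loss(out2, y2)).backward()
    g_total = [p.grad.clone() for p in model.backbone.parameters()]

    # ...compared with the backbone gradient from Head1's loss alone.
    model.zero_grad(set_to_none=True)
    out1, _ = model(x)
    F.cross_entropy(out1, y1).backward()
    g_head1 = [p.grad.clone() for p in model.backbone.parameters()]

    same = all(torch.allclose(a, b) for a, b in zip(g_total, g_head1))
    return backbone_untouched, head2_has_grad, same


if __name__ == "__main__":
    print(check_gradient_flow())
```

Because the detached tensor has no link back to the backbone, you can simply add the two losses and call `backward()` once in your training loop: Head2 is still trained on its own loss, while the backbone only sees Head1's gradient.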

