Update weights using gradients from one path but not another

Hello,

I have a network with two parallel heads; each head predicts a different entity, so they have separate losses. I want the backbone to be updated by gradients coming from one head but not the other. How can I do this?

(Attaching a figure explaining the same)

Thanks!

You could call .detach() on the output of the backbone before passing it to Head2. That cuts the computation graph at that point, so Head2's loss will still update Head2's own parameters but will not propagate any gradients into the backbone.
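A minimal sketch of the idea (the module names, shapes, and MSE losses here are placeholders, not your actual network):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the shared backbone and the two heads.
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
head1 = nn.Linear(32, 4)  # head whose gradients SHOULD reach the backbone
head2 = nn.Linear(32, 2)  # head whose gradients should NOT reach the backbone

x = torch.randn(8, 16)
target1 = torch.randn(8, 4)
target2 = torch.randn(8, 2)

features = backbone(x)

# Head1 receives the ordinary output: its loss backpropagates into the backbone.
out1 = head1(features)

# Head2 receives a detached copy: .detach() cuts the graph here, so its loss
# only produces gradients for Head2's own parameters, never the backbone's.
out2 = head2(features.detach())

loss = F.mse_loss(out1, target1) + F.mse_loss(out2, target2)
loss.backward()

# backbone gradients now come only from Head1's loss,
# while Head2 still receives gradients from its own loss.
```

Note that both losses can still be summed into a single `loss.backward()` call; the detach alone determines which path contributes gradients to the backbone.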
