Implementing a custom backward pass with a many-to-one backward graph

I am trying to implement this paper : https://papers.nips.cc/paper/7454-collaborative-learning-for-deep-neural-networks.pdf

The model proposed in the paper has multiple classifier heads with layers shared between the classifiers, so the backward pass becomes a many-to-one graph.

How do I go about implementing this kind of backward() function?

I am not very familiar with the paper, but computing the shared part once and then feeding the same features to each classification head will work. You then sum all the losses you want and call backward() on the result!
Are there any complications I'm missing here?
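A minimal sketch of that idea, assuming a toy two-head model (the module names and sizes here are made up for illustration, not taken from the paper):

```python
import torch
import torch.nn as nn

class CollaborativeNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Layers shared by all classifier heads
        self.shared = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
        )
        # Two independent classifier heads on top of the shared features
        self.head1 = nn.Linear(256, num_classes)
        self.head2 = nn.Linear(256, num_classes)

    def forward(self, x):
        feats = self.shared(x)  # shared part computed once
        return self.head1(feats), self.head2(feats)

model = CollaborativeNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(32, 784)
target = torch.randint(0, 10, (32,))

logits1, logits2 = model(x)
# Sum the per-head losses and call backward() once; autograd accumulates
# gradients from both heads into the shared layers automatically.
loss = criterion(logits1, target) + criterion(logits2, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

The point is that no custom backward() is needed: because both heads consume the same shared features, autograd handles the many-to-one graph by summing the gradients flowing back into the shared layers.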
