Training multiple heads in a hierarchy

Hi!
I am trying to train multiple heads on a ResNet backbone. Currently I train the ResNet + Head0 first, then train the other head, Head1, on the outputs of the ResNet backbone. I do it this way because I do not want Head1 to affect the backbone at all. However, since I cannot train them simultaneously, this takes longer, and I cannot observe the convergence characteristics of Head1 while the backbone has not yet fully converged.
I want to train (Backbone+Head0) and (Head1) as asynchronously as possible, with Head1 never causing the Backbone's parameters to be updated: it should only read the backbone's outputs (inference only) and use them as the input for its own training.
Even though the problem is easy to define in terms of its mathematics, I do not even know where to start.
Can you please help me?

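You can call .detach() on the output of the backbone before passing it to Head1, and pass the undetached output directly to Head0. The detached tensor is cut off from the autograd graph, so Head1's loss cannot send gradients into the backbone, and both heads can be trained in the same loop.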
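A minimal sketch of what that could look like, in case it helps. The tiny `backbone`, `head0` and `head1` modules, the random data, the two SGD optimizers and the `train_step` helper are all placeholders I made up for illustration; in your case the backbone would be the ResNet and the losses whatever your heads need.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in modules (assumptions): replace with your ResNet and real heads.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head0 = nn.Linear(16, 4)
head1 = nn.Linear(16, 3)

# One optimizer for Backbone+Head0, a separate one for Head1.
opt_main = torch.optim.SGD(
    list(backbone.parameters()) + list(head0.parameters()), lr=0.1
)
opt_head1 = torch.optim.SGD(head1.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Dummy batch (assumption): 32 samples with a label for each head.
x = torch.randn(32, 8)
y0 = torch.randint(0, 4, (32,))
y1 = torch.randint(0, 3, (32,))


def train_step(x, y0, y1):
    features = backbone(x)  # single forward pass through the backbone

    # Head0 sees the live features, so its loss trains the backbone too.
    loss0 = criterion(head0(features), y0)

    # Head1 sees a detached copy: no gradient flows back into the backbone.
    loss1 = criterion(head1(features.detach()), y1)

    opt_main.zero_grad()
    opt_head1.zero_grad()
    # Summing is safe here: the two losses share no trainable parameters.
    (loss0 + loss1).backward()
    opt_main.step()
    opt_head1.step()
    return loss0.item(), loss1.item()


loss0_value, loss1_value = train_step(x, y0, y1)
```

Because the backbone runs only once per batch, you can log both losses at every step and watch how Head1 behaves while the backbone is still changing.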
