I have a Main Network (a large transformer) to which I add a Side Network.
My goal is to use the Side Network for fine-tuning, without updating the weights of the Main Network.
One way to do this is to freeze the Main Network's weights so they are not updated.
But the problem is that even if the weights are not updated, the backpropagation graph (for Main Network + Side Network) is still computed, which is time- and memory-consuming, since I have direct connections between the Side Network modules. I want to backpropagate only through the Side Network.
Currently, I have just one loss, loss = Y - label. Is there a way to use this same loss to backpropagate through ONLY the Side Network?
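To make the setup concrete, here is a minimal sketch of the kind of architecture I mean (the module names, sizes, and the exact form of the intermediate connections are placeholder assumptions, not our real code):

```python
import torch
import torch.nn as nn

class MainWithSide(nn.Module):
    def __init__(self, dim=512, num_layers=4):
        super().__init__()
        # Stand-ins for the large transformer blocks.
        self.main_blocks = nn.ModuleList(
            nn.Linear(dim, dim) for _ in range(num_layers)
        )
        # Lightweight side modules, connected in parallel.
        self.side_blocks = nn.ModuleList(
            nn.Linear(dim, dim) for _ in range(num_layers)
        )
        # Freeze the Main Network so the optimizer never updates it.
        for p in self.main_blocks.parameters():
            p.requires_grad = False

    def forward(self, x):
        side = x
        for main_block, side_block in zip(self.main_blocks, self.side_blocks):
            x = main_block(x)
            # Intermediate connection: each side block also consumes
            # the Main Network's activations at the same depth.
            side = side_block(side + x)
        return side
```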
Thanks for your reply.
I checked it before, but it does not seem to have enough explanation.
If I call detach() on the Main Network's output (e.g. MainNetwork(x).detach()), it would be excluded from the computation graph when backpropagating the loss. Is my understanding correct?
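In other words, something along these lines (a rough sketch with stand-in modules, not our actual networks):

```python
import torch
import torch.nn as nn

main_network = nn.Linear(512, 512)   # stand-in for the large transformer
side_network = nn.Linear(1024, 512)  # stand-in for the side network

x = torch.randn(8, 512)
label = torch.randn(8, 512)

# detach() cuts the main path out of the autograd graph, so no
# gradients can flow back into main_network.
main_out = main_network(x).detach()

# Only the side path is tracked by autograd.
side_out = side_network(torch.cat([x, main_out], dim=-1))

loss = (side_out - label).abs().mean()  # scalar reduction of loss = Y - label
loss.backward()  # traverses side_network only

# The Main Network never receives gradients.
assert all(p.grad is None for p in main_network.parameters())
```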
@smth Thanks for the extra information.
I have two concerns with this:
1. The code seems to consider a series connection of the Main and Side Networks, but in our case they are connected in parallel (with intermediate connections).
2. Although the MainNetwork and SideNetwork are shown separately, the SideNetwork's layers are defined within the MainNetwork's forward function, because they have to process the same input.
That said, by using the detach_() function at the MainNetwork's outputs, the MainNetwork seems to have been excluded from the computation graph, and the loss now appears to backpropagate only through the layers of the SideNetwork, roughly like the sketch below. Do you have any comments on this approach?
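For reference, this is roughly what the forward looks like now (layer names and sizes are placeholders; detach_() is applied in-place at each main layer's output before it feeds the side layers):

```python
import torch
import torch.nn as nn

class MainWithSide(nn.Module):
    def __init__(self, dim=512, num_layers=4):
        super().__init__()
        self.main_layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.side_layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        for p in self.main_layers.parameters():
            p.requires_grad = False  # frozen, as before

    def forward(self, x):
        side = x
        for main_layer, side_layer in zip(self.main_layers, self.side_layers):
            x = main_layer(x)
            # In-place detach at the MainNetwork's output: even if the
            # main path required grad, it is dropped from the graph
            # here, so backward() only traverses the side layers.
            x.detach_()
            side = side_layer(side + x)
        return side

model = MainWithSide()
out = model(torch.randn(8, 512))
out.mean().backward()
assert all(p.grad is None for p in model.main_layers.parameters())
assert all(p.grad is not None for p in model.side_layers.parameters())
```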
Thanks!