Backpropagating the loss through a side network only

I have a Main Network (a large transformer) to which I have added a Side Network.
My goal is to fine-tune only the Side Network and leave the weights of the Main Network unchanged.
One way to do this is to freeze the Main Network weights so that they are not updated.
The problem is that even when the weights are not updated, the backpropagation graph for Main Network + Side Network is still computed, which costs time and memory. Since the side network modules are directly connected to each other, I want to backpropagate only through the side network.
Currently, I have just one loss, loss = Y - label. Is there a way to use the same loss to backpropagate through ONLY the side network?

Look at the detach function:

Thanks for your reply.
I checked it before, but the documentation does not seem to explain it in enough detail.

Using detach on the Main Network, i.e. MainNetwork.detach(), would exclude it from the computation graph when backpropagating the loss. Is my understanding correct?

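Take the example code below. Here, output is detached before it is passed to SideNetwork, so when you call Y.backward(), it won't backpropagate through MainNetwork. Note that detach() is called on the output tensor, not on the network itself.

output = MainNetwork(X)
output = output.detach()  # cut the graph: no gradients flow back into MainNetwork
Y = SideNetwork(output)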
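One way to check the effect of detach() is to inspect the .grad attributes after the backward pass. The sketch below is only an illustration: the two nn.Linear layers are hypothetical stand-ins for MainNetwork and SideNetwork, and the layer sizes and the .sum() reduction are assumptions made so the example runs on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two networks (sizes are arbitrary).
main_network = nn.Linear(8, 8)
side_network = nn.Linear(8, 1)

x = torch.randn(4, 8)

output = main_network(x)
output = output.detach()        # cut the graph between the two networks
y = side_network(output).sum()  # reduce to a scalar so backward() needs no argument
y.backward()

# Gradients never reach the main network, so its .grad fields stay None.
main_grads = [p.grad for p in main_network.parameters()]
side_grads = [p.grad for p in side_network.parameters()]
print("main grads:", main_grads)
print("side grads present:", all(g is not None for g in side_grads))
```

Keep in mind that detach() only stops the backward pass. The forward pass through the main network still records a graph until the detached tensor replaces the original; wrapping that forward pass in torch.no_grad() avoids recording it in the first place.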

@smth Thanks for the extra information.
I have two concerns with this:

  1. The code seems to be considering a series connection of the Main and Side Networks but in our case, it is parallel (with intermediate connections).
  2. In our case, although the MainNetwork and SideNetwork are shown separately, the SideNetwork layers are defined within the MainNetwork’s forward function, because they have to process the same input.

That said, after applying detach_() to the MainNetwork’s output, the MainNetwork appears to be excluded from the computation graph, and backpropagation seems to go only through the SideNetwork layers. Do you have any comments on this?
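For the parallel case with intermediate connections, the same idea can be applied at every point where a main activation enters the side path. The following is a rough sketch under assumptions, not the actual architecture from this thread: the class MainWithSide, its layer sizes, and the additive connection are all hypothetical.

```python
import torch
import torch.nn as nn


class MainWithSide(nn.Module):
    """Hypothetical model whose side layers live inside the main forward pass."""

    def __init__(self, dim=8, num_layers=3):
        super().__init__()
        self.main_layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.side_layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.head = nn.Linear(dim, 1)

    def forward(self, x):
        h_main = x
        h_side = torch.zeros_like(x)
        for main_layer, side_layer in zip(self.main_layers, self.side_layers):
            h_main = torch.relu(main_layer(h_main))
            # Detach each intermediate main activation before it feeds the side
            # path, so no gradient flows back into the main layers.
            h_side = torch.relu(side_layer(h_side + h_main.detach()))
        return self.head(h_side)


torch.manual_seed(0)
model = MainWithSide()
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()

main_has_grad = any(p.grad is not None for p in model.main_layers.parameters())
side_has_grad = all(p.grad is not None for p in model.side_layers.parameters())
print("main layers received gradients:", main_has_grad)
print("side layers received gradients:", side_has_grad)
```

Detaching at each connection, rather than only at the final output, matters here because every intermediate link would otherwise give gradients a route back into the main layers.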

Here are the considerations you want to make when fine-tuning a side network:

  1. Only pass the side network's parameters to the optimizer. That way, the optimizer tracks and updates only the side network's parameters.
  2. Call model.eval() on the main network so that Dropout is turned off and batchnorm layers operate in eval mode.
  3. Wrap the main network's forward pass in with torch.no_grad(): or call .detach_() on its outputs, as @smth mentioned.
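The three points above can be combined into a single training step. This is a minimal sketch under stated assumptions: main_network and side_network are small hypothetical stand-ins connected in series, the SGD optimizer and MSE loss are arbitrary choices, and the explicit requires_grad_(False) freeze is an optional extra that matches the freezing mentioned in the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins; the real main network would be the large transformer.
main_network = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))
side_network = nn.Linear(8, 1)

# 1. The optimizer only sees the side network's parameters.
optimizer = torch.optim.SGD(side_network.parameters(), lr=0.1)

# 2. Eval mode disables Dropout and makes batchnorm use its running statistics.
main_network.eval()
for p in main_network.parameters():
    p.requires_grad_(False)  # optional explicit freeze

x = torch.randn(4, 8)
label = torch.randn(4, 1)

main_before = [p.detach().clone() for p in main_network.parameters()]
side_before = [p.detach().clone() for p in side_network.parameters()]

# 3. No graph is recorded for the main network's forward pass.
with torch.no_grad():
    features = main_network(x)

loss = nn.functional.mse_loss(side_network(features), label)
optimizer.zero_grad()
loss.backward()
optimizer.step()

main_unchanged = all(torch.equal(a, b) for a, b in zip(main_before, main_network.parameters()))
side_changed = any(not torch.equal(a, b) for a, b in zip(side_before, side_network.parameters()))
print("main unchanged:", main_unchanged, "| side updated:", side_changed)
```

If both networks are submodules of one parent module, calling train() on the parent puts the main network back into training mode, so eval() would need to be called on it again afterwards.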