For a given model I have a normal loss function that is supposed to train the whole net. I'd also like to use another loss function that trains only a single output layer. I don't want this auxiliary loss to affect the rest of the network; it should only train the output layer it is computed on. How can I code this up in PyTorch?
What I have tried so far is below; however, in practice I did not see loss_aux improving.
```python
import torch
import torch.nn as nn

criterion_main = nn.L1Loss()
criterion_aux = nn.BCELoss()  # expects out_aux to be probabilities in [0, 1]
optimizer_main = torch.optim.Adam(net.parameters())
# aux_out is the layer I'd like to train with criterion_aux
optimizer_aux = torch.optim.Adam(net.aux_out.parameters())

# in the training loop
loss_main = criterion_main(out, labels)
loss_aux = criterion_aux(out_aux, labels_aux)
```
The simplest way (though not the most efficient one) would be to create a new tensor with the same data (e.g. with detach()), pass it separately through the last layer, and then apply the loss function.
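A minimal sketch of the detach approach, assuming a hypothetical toy model: a shared `trunk` produces features, `head` gives the main output, and `aux_out` (the name from the question) gives the auxiliary output. Layer sizes and data are made up for illustration. Because the aux head only sees `feats.detach()`, the aux loss can produce gradients for `aux_out` alone, so one optimizer over all parameters and a summed loss should be enough:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical toy model: shared trunk, main head, and an auxiliary head aux_out."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.head = nn.Linear(16, 1)
        # Sigmoid so that nn.BCELoss receives probabilities in [0, 1].
        self.aux_out = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())

    def forward(self, x):
        feats = self.trunk(x)
        out = self.head(feats)
        # detach() returns a tensor with the same data but no autograd history,
        # so gradients of the aux loss stop at aux_out and never reach the trunk.
        out_aux = self.aux_out(feats.detach())
        return out, out_aux


torch.manual_seed(0)
net = Net()
criterion_main = nn.L1Loss()
criterion_aux = nn.BCELoss()
# A single optimizer: each loss only yields gradients for its own part of the net.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# Made-up data, only to make the sketch runnable.
x = torch.randn(32, 8)
labels = torch.randn(32, 1)
labels_aux = torch.randint(0, 2, (32, 1)).float()

for _ in range(100):
    optimizer.zero_grad()
    out, out_aux = net(x)
    loss_main = criterion_main(out, labels)
    loss_aux = criterion_aux(out_aux, labels_aux)
    (loss_main + loss_aux).backward()
    optimizer.step()
```

Here `feats` is computed once and reused, so the trunk still runs a single forward pass; `detach()` only cuts the autograd link between the trunk and `aux_out`.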
A separate optimizer does not work, as far as I understand, because the two optimizations could lead to different optima and therefore kind of “revert” each other's steps.
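If two optimizers are used anyway, one way to rule out the interference described above is to give them disjoint parameter sets, so that no parameter is ever stepped by both. A sketch under the same assumptions (hypothetical toy model, made-up sizes, aux input detached): `optimizer_main` gets every parameter except those of `aux_out`, and `optimizer_aux` gets only `aux_out`.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical toy model; only the attribute name aux_out comes from the question."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.head = nn.Linear(16, 1)
        self.aux_out = nn.Linear(16, 1)


net = Net()
criterion_main = nn.L1Loss()
criterion_aux = nn.BCELoss()

# Split the parameters so each one belongs to exactly one optimizer.
aux_ids = {id(p) for p in net.aux_out.parameters()}
main_params = [p for p in net.parameters() if id(p) not in aux_ids]

optimizer_main = torch.optim.Adam(main_params)
optimizer_aux = torch.optim.Adam(net.aux_out.parameters())


def train_step(x, labels, labels_aux):
    optimizer_main.zero_grad()
    optimizer_aux.zero_grad()
    feats = net.trunk(x)
    out = net.head(feats)
    # Detached input: the aux loss cannot push gradients into the trunk.
    out_aux = torch.sigmoid(net.aux_out(feats.detach()))
    loss_main = criterion_main(out, labels)
    loss_aux = criterion_aux(out_aux, labels_aux)
    (loss_main + loss_aux).backward()
    optimizer_main.step()  # updates trunk and head only
    optimizer_aux.step()   # updates aux_out only
    return loss_main.item(), loss_aux.item()
```

With overlapping sets, as in the question where `optimizer_main` is built from `net.parameters()`, any gradient left in `aux_out` by `loss_aux` would be applied by both optimizers, each with its own Adam state.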
I think what you suggest might be wrong. The two optimizers are scoped to different parts of the network. Yes, the main optimizer is given the whole network's parameters, but its loss is computed over the main output only. Therefore, when I call backward() on the main loss, the aux output layer is not updated. So I believe they do not overwrite each other; at least, this is what I imagine here.
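This claim can be checked directly. A minimal sketch with a hypothetical toy model whose `aux_out` branches off the shared features: after `loss_main.backward()`, `aux_out` receives no gradient (its `.grad` stays `None`), and `torch.optim.Adam` skips parameters whose `.grad` is `None`, so `optimizer_main.step()` leaves `aux_out` unchanged even though the optimizer was built from `net.parameters()`.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical toy model: aux_out branches off the shared trunk features."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.head = nn.Linear(16, 1)
        self.aux_out = nn.Linear(16, 1)

    def forward(self, x):
        feats = self.trunk(x)
        return self.head(feats), self.aux_out(feats)


torch.manual_seed(0)
net = Net()
optimizer_main = torch.optim.Adam(net.parameters(), lr=1e-2)

# Snapshot weights before the step.
aux_before = net.aux_out.weight.detach().clone()
trunk_before = net.trunk[0].weight.detach().clone()

optimizer_main.zero_grad(set_to_none=True)
out, out_aux = net(torch.randn(4, 8))
loss_main = nn.L1Loss()(out, torch.randn(4, 1))
loss_main.backward()  # out does not depend on aux_out, so aux_out gets no grad

aux_grad_is_none = net.aux_out.weight.grad is None
optimizer_main.step()  # Adam skips parameters whose .grad is None

aux_unchanged = torch.equal(net.aux_out.weight, aux_before)
trunk_changed = not torch.equal(net.trunk[0].weight, trunk_before)
```

This relies on `aux_out` holding no gradient when `optimizer_main.step()` runs. If a gradient from an earlier `loss_aux.backward()` is still present (for example, when `zero_grad()` is not called between the two backward passes), `optimizer_main` will apply it as well. Note also that in this model the aux input is not detached, so `loss_aux.backward()` would reach the shared trunk.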
From what you wrote, it seemed to me that you pass your variable through a net (including aux_out) and want to update this net's parameters. If you do not pass the variable for the main forward pass through the aux output layer, you are right: then the aux_out parameters won't be updated by the main optimizer.
Yep, you are right; to make things clearer I should probably post the whole code. However, when I run the code above, I think the network trains as I want. Compared to training the whole net, it performs worse on the aux output, yet better than random.