I have a network in which a particular output layer needs to be trained independently of the rest of the network. What I've done so far is the following:
optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
optimizer_st = optim.Adam(
    model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
....
# backward pass and grad norm check for the spec losses
loss.backward(retain_graph=True)
optimizer, current_lr = weight_decay(optimizer, c.wd)
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
# backward pass and grad norm check for the stop loss
stop_loss.backward()
optimizer_st, _ = weight_decay(optimizer_st, c.wd)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
optimizer_st.step()
It actually works so far, but it slows down training a lot. Is there a better way to do this that I'm probably missing?
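For example, would something along these lines be a reasonable alternative? The idea is to detach the stopnet input so stop_loss gradients stay inside stopnet, then sum the losses and call backward() only once, avoiding retain_graph=True and the second backward pass. This is just a minimal, self-contained sketch with a toy module and made-up names, not my actual model, and it uses a single clip value for all parameters instead of the separate clipping above:

import torch
import torch.nn as nn
import torch.optim as optim

class ToyDecoder(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.core = nn.Linear(dim, dim)    # stands in for the main decoder
        self.stopnet = nn.Linear(dim, 1)   # output layer trained independently

    def forward(self, x):
        h = self.core(x)
        # detach() blocks stop_loss gradients from flowing back into the core,
        # so stopnet is effectively trained on its own objective
        stop_logits = self.stopnet(h.detach())
        return h, stop_logits

model = ToyDecoder()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)

x = torch.randn(8, 32)
spec_target = torch.randn(8, 32)
stop_target = torch.randint(0, 2, (8, 1)).float()

h, stop_logits = model(x)
loss = nn.functional.mse_loss(h, spec_target)
stop_loss = nn.functional.binary_cross_entropy_with_logits(stop_logits, stop_target)

# single backward over the summed loss: no retain_graph, one pass through the graph
(loss + stop_loss).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()

Would this be equivalent to the two-optimizer setup above, or am I losing something by clipping and stepping everything together?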