How to train only the middle component?

# train encoder and classifier.
out = encoder(input)
pred = classifier(out)
ce_loss = ce(pred, target)
total_loss += ce_loss

# In this training phase, I want to train only encoder_2
out = encoder(input.detach())
out_2 = encoder_2(out)
pred_2 = classifier(out_2.detach())
ce_loss2 = ce(pred_2, target)
total_loss += ce_loss2


Flow: Encoder1 -> Encoder2 -> Classifier.

I want to train Encoder1 and Classifier with ce_loss only, and Encoder2 with ce_loss2 only.
In this case, how do I freeze the encoder and classifier while training Encoder2?

I don’t know which of these I should use: detach, no_grad(), or param.requires_grad = grad_on?
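
For reference, here is a minimal sketch of one way the three options could be combined. It is not a confirmed answer. The `nn.Linear` layers, tensor shapes, and the name `inputs` are hypothetical stand-ins for the real modules and data. The idea: `detach()` on the encoder output stops ce_loss2 from reaching the encoder, and switching off `requires_grad` on the classifier parameters during the second forward pass keeps ce_loss2 out of the classifier weights while still letting the gradient pass through the classifier to encoder_2. Wrapping the classifier call in `no_grad()`, or detaching `out_2`, would instead cut the path to encoder_2 as well.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the modules in the question.
encoder = nn.Linear(8, 16)
encoder_2 = nn.Linear(16, 16)
classifier = nn.Linear(16, 4)
ce = nn.CrossEntropyLoss()

# Dummy batch (assumed shapes, for illustration only).
inputs = torch.randn(32, 8)
target = torch.randint(0, 4, (32,))

# Phase 1: encoder and classifier are trained with ce_loss.
out = encoder(inputs)
pred = classifier(out)
ce_loss = ce(pred, target)

# Phase 2: only encoder_2 should receive gradients from ce_loss2.
# detach() cuts the graph so nothing flows back into encoder.
out_2 = encoder_2(out.detach())

# Freeze classifier parameters for this forward pass only.
# The graph is recorded at forward time, so these parameters get no
# gradient from ce_loss2, but the gradient w.r.t. out_2 still flows.
for p in classifier.parameters():
    p.requires_grad_(False)
pred_2 = classifier(out_2)
ce_loss2 = ce(pred_2, target)
for p in classifier.parameters():
    p.requires_grad_(True)

total_loss = ce_loss + ce_loss2
total_loss.backward()
```

After `backward()`, the gradients on `encoder` and `classifier` should come from ce_loss alone and those on `encoder_2` from ce_loss2 alone, so a single optimizer over all parameters could then take its step. Using a separate optimizer per loss would be another option.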


I think there are a few related threads