How does PyTorch train a network in multiple stages?

For example, my network has two branches, A and B. How can I train A first and then B? Specifically, B is fixed while training A, and A is fixed while training B. Is there any code similar to this?

You can freeze the parameters by setting their requires_grad attribute to False. Then you can set it back to True when you want to update them again.

for param in A_branch.parameters():
    param.requires_grad = False
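A minimal sketch of the two-stage schedule, assuming a hypothetical model with two linear branches (`branch_a`, `branch_b`, and the helper `set_requires_grad` are illustrative names, not from the original post). Each stage freezes one branch and rebuilds the optimizer over the still-trainable parameters:

```python
import torch
import torch.nn as nn

# Hypothetical two-branch model for illustration only.
class TwoBranchNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Linear(4, 8)
        self.branch_b = nn.Linear(4, 8)
        self.head = nn.Linear(8, 1)

    def forward(self, x):
        return self.head(self.branch_a(x) + self.branch_b(x))

def set_requires_grad(module, flag):
    # Freeze (flag=False) or unfreeze (flag=True) every parameter in a branch.
    for param in module.parameters():
        param.requires_grad = flag

model = TwoBranchNet()
x = torch.randn(16, 4)
y = torch.randn(16, 1)
loss_fn = nn.MSELoss()

# Stage 1: train A (and the head) with B frozen.
set_requires_grad(model.branch_b, False)
opt = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1
)
b_before = model.branch_b.weight.clone()
for _ in range(5):
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()
# Frozen parameters receive no gradient, so B is untouched.
assert torch.equal(model.branch_b.weight, b_before)

# Stage 2: unfreeze B, freeze A, and build a fresh optimizer.
set_requires_grad(model.branch_b, True)
set_requires_grad(model.branch_a, False)
opt = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1
)
a_before = model.branch_a.weight.clone()
for _ in range(5):
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()
assert torch.equal(model.branch_a.weight, a_before)
```

Rebuilding the optimizer each stage keeps frozen parameters out of it entirely; otherwise optimizers with state (e.g. momentum or Adam) can still nudge frozen weights from stale buffers.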