Freezing and training subcomponents of the model

Hey folks!

I have a model that consists of 3 submodules (A, B, and C) stacked sequentially, so the pipeline looks like A–>B–>C. For the first few epochs, I train A–>B and compute the loss; thereafter, I freeze the A and B submodules and train only C. The forward pass looks like this:

def forward(self, batch):
    z = self.B(self.A(batch))
    if self.current_epoch > 30:
        # reuse the already computed A->B features instead of recomputing them
        z = self.C(z)
    .....

def callback(self):
    if self.current_epoch > 30 and self.freeze:
        # freeze A and B
        for param in self.A.parameters():
            param.requires_grad = False
        for param in self.B.parameters():
            param.requires_grad = False

        # re-initialize the optimizer so it only updates C
        self.optimizer[0] = torch.optim.Adam(self.C.parameters(), lr=self.lr)
        self.freeze = False

Note that I also re-initialize the Adam optimizer so that it only trains the parameters of C.

I did not get the desired results with this pipeline. However, if I train a model that consists only of A–>B, store its outputs (z = B(A(batch))) externally as a database after training, and then use this database to train a separate model that contains only C, it works for me. Basically, I have to break the pipeline into two stages to get results rather than using a single pipeline. Any suggestions on where I might be going wrong?
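For reference, the two-stage workaround looks roughly like this (a minimal sketch with placeholder nn.Linear layers, shapes, and the file name feature_db.pt; in my real setup the features come from the trained A–>B model and are stored in my external database):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# placeholder stand-ins for the real submodules and data
A, B, C = nn.Linear(16, 32), nn.Linear(32, 8), nn.Linear(8, 1)
loader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randn(64, 1)), batch_size=8)

# stage 1: after training A->B, cache its outputs without tracking gradients
A.eval()
B.eval()
features, targets = [], []
with torch.no_grad():
    for x, y in loader:
        features.append(B(A(x)))
        targets.append(y)
torch.save((torch.cat(features), torch.cat(targets)), "feature_db.pt")

# stage 2: train a separate model containing only C on the cached features
feats, targs = torch.load("feature_db.pt")
c_loader = DataLoader(TensorDataset(feats, targs), batch_size=8)
optimizer = torch.optim.Adam(C.parameters(), lr=1e-3)
for z, y in c_loader:
    loss = nn.functional.mse_loss(C(z), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()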

I would recommend running a few sanity checks and making sure the frozen parameters are indeed no longer updated after 30 epochs, that all C parameters get valid gradients and are updated, etc.
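For example, something along these lines (a minimal sketch with placeholder nn.Linear layers standing in for A, B, and C; adapt it to your actual modules and training step):

import torch
from torch import nn

# placeholder submodules standing in for A, B, and C
A, B, C = nn.Linear(16, 32), nn.Linear(32, 8), nn.Linear(8, 1)
for param in list(A.parameters()) + list(B.parameters()):
    param.requires_grad = False
optimizer = torch.optim.Adam(C.parameters(), lr=1e-3)

# snapshot all parameters before a single training step
frozen_before = {f"A.{n}": p.detach().clone() for n, p in A.named_parameters()}
frozen_before.update({f"B.{n}": p.detach().clone() for n, p in B.named_parameters()})
c_before = {n: p.detach().clone() for n, p in C.named_parameters()}

# one training step
x = torch.randn(4, 16)
loss = C(B(A(x))).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# frozen parameters should have no gradients and unchanged values
for prefix, module in (("A", A), ("B", B)):
    for n, p in module.named_parameters():
        assert p.grad is None, f"{prefix}.{n}"
        assert torch.equal(frozen_before[f"{prefix}.{n}"], p), f"{prefix}.{n}"

# C parameters should have valid gradients and updated values
for n, p in C.named_parameters():
    assert p.grad is not None and torch.isfinite(p.grad).all(), n
    assert not torch.equal(c_before[n], p), n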
If all of this looks OK, I would then double-check the outputs of B(A(batch)) from both pipelines using a static input (e.g. just torch.ones or one specific sample) and compare these features to see if something still differs in the forward passes.
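The comparison could look like this (again just a sketch; model_single and model_two_stage are hypothetical handles to your single-pipeline and two-stage models, both assumed to expose A and B as attributes, and the input shape is a placeholder):

import torch

# switch both pipelines to eval mode so the comparison is deterministic
model_single.eval()
model_two_stage.eval()

# fixed, static input shared by both pipelines (shape is a placeholder)
x = torch.ones(1, 16)

with torch.no_grad():
    z_single = model_single.B(model_single.A(x))
    z_two_stage = model_two_stage.B(model_two_stage.A(x))

# the intermediate features should match closely if both forward passes are equivalent
print(torch.allclose(z_single, z_two_stage, atol=1e-6))
print((z_single - z_two_stage).abs().max())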