Mixing trained and untrained models

If I have a pretrained autoencoder/decoder that I want to use, and I want to train a model that uses that decoder as its output stage, how do I train them properly? I have several questions:

Should I call .detach() on the tensor before feeding it to the decoder, like this:

def forward(self, input):
  out = self.main(input)
  out = out.detach()  # <- should I do this?
  out = self.decoder(out)
  return out

If not, should I set up the optimizer to train only the new model, or the model plus the decoder so they fit together better?
So if my model has two modules, main and decoder, should I do this:
optimizer = optim.Adam(model.main.parameters())
or this:
optimizer = optim.Adam(model.parameters())

It depends on what your loss term is. Is the loss computed directly from the output of the decoder? If so, don’t detach the tensor.

Detaching a tensor is the same as saying “I don’t want gradients to flow through here”. In this scenario, your loss is computed from the output of the decoder, so gradients flow back to the start of the decoder and stop there. Your main module will never see any gradients.
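
A small sketch of that point, assuming two toy nn.Linear layers stand in for the main module and the pretrained decoder, with an MSE loss on the decoder output (all hypothetical choices, not the poster's actual models). With .detach() the main module's weights end up with no gradient at all; without it they do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: "main" is the new, untrained module and
# "decoder" plays the role of the pretrained decoder.
main = nn.Linear(4, 8)
decoder = nn.Linear(8, 4)

x = torch.randn(2, 4)
target = torch.randn(2, 4)

def grads_reach_main(use_detach):
    # Clear any previous gradients so each call starts from None.
    main.zero_grad(set_to_none=True)
    decoder.zero_grad(set_to_none=True)
    h = main(x)
    if use_detach:
        h = h.detach()  # cuts the graph: nothing upstream of h gets gradients
    loss = nn.functional.mse_loss(decoder(h), target)
    loss.backward()
    # The decoder always receives gradients; main only does without detach.
    return main.weight.grad is not None

print(grads_reach_main(use_detach=True))   # False
print(grads_reach_main(use_detach=False))  # True
```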

To answer your optimizer question: it depends. There’s no wrong answer, really. You may want to train them jointly so that the decoder can learn to work better with your main module. You can also choose to let the main module train on its own and adapt itself to the already trained decoder. I would personally suggest letting them train jointly (so model.parameters()), but you wouldn’t be wrong to train only the main module. Gradients have to flow through the decoder anyway to reach the main module (this would be a different story if you wanted to freeze the encoder instead), so I’m guessing speed-wise it’ll be about the same.
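
A minimal sketch of both optimizer setups, assuming a hypothetical Model with main and decoder submodules and a reconstruction-style MSE loss. Freezing is done here by calling requires_grad_(False) on the decoder's parameters in addition to leaving them out of the optimizer; the helper names make_optimizer and decoder_changed_after_one_step are made up for illustration.

```python
import copy
import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
        # Stand-in for the pretrained decoder; in practice you would load
        # its weights, e.g. with self.decoder.load_state_dict(...).
        self.decoder = nn.Linear(8, 16)

    def forward(self, x):
        return self.decoder(self.main(x))

def make_optimizer(model, joint):
    # Frozen decoder: gradients still flow *through* it to main, but no
    # gradients are computed or stored for the decoder's own weights.
    for p in model.decoder.parameters():
        p.requires_grad_(joint)
    params = model.parameters() if joint else model.main.parameters()
    return optim.Adam(params, lr=1e-3)

def decoder_changed_after_one_step(joint):
    torch.manual_seed(0)
    model = Model()
    before = copy.deepcopy(model.decoder.state_dict())
    opt = make_optimizer(model, joint)
    x = torch.randn(32, 16)
    loss = nn.functional.mse_loss(model(x), x)  # loss on the decoder output
    opt.zero_grad()
    loss.backward()
    opt.step()
    after = model.decoder.state_dict()
    return any(not torch.equal(before[k], after[k]) for k in before)

print(decoder_changed_after_one_step(joint=False))  # False
print(decoder_changed_after_one_step(joint=True))   # True
```

If you only leave the decoder out of the optimizer but keep requires_grad=True, autograd still computes and accumulates gradients for its weights; they just never get applied. Setting requires_grad to False skips those weight gradients, although gradients still flow through the decoder's activations back to main. If the decoder contains BatchNorm or Dropout layers, you may also want model.decoder.eval() while it is frozen.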

1 Like

Maybe train the model first while keeping the decoder unchanged, then enable decoder training so they can better adapt to each other?

Seems viable to me. That follows the idea of “learn to work with the decoder first, then we can train you both jointly”. I don’t think there’s a wrong way to approach it; maybe try out the different methods and see whether it even makes a difference.
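
The two-phase idea could look roughly like this, assuming toy modules and random data in place of the real ones. Phase 2 uses the optimizer's add_param_group to bring the decoder in with its own, smaller learning rate; the epoch counts and learning rates are arbitrary placeholders, not tuned values, and run_epochs is a hypothetical helper.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical modules; "decoder" stands in for the pretrained one.
main = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
decoder = nn.Linear(8, 16)
data = torch.randn(256, 16)

def run_epochs(opt, n_epochs):
    # Simple reconstruction loop over mini-batches of 32 rows.
    for _ in range(n_epochs):
        for batch in data.split(32):
            loss = nn.functional.mse_loss(decoder(main(batch)), batch)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return loss.item()

# Phase 1: decoder frozen, only main is optimized.
decoder.requires_grad_(False)
opt = optim.Adam(main.parameters(), lr=1e-3)
loss_phase1 = run_epochs(opt, n_epochs=5)

# Phase 2: unfreeze the decoder and add it as its own param group.
# The smaller learning rate (an arbitrary choice here) is meant to
# fine-tune the pretrained weights gently rather than overwrite them.
decoder.requires_grad_(True)
opt.add_param_group({"params": decoder.parameters(), "lr": 1e-4})
loss_phase2 = run_epochs(opt, n_epochs=5)

print(f"phase 1 final loss: {loss_phase1:.4f}")
print(f"phase 2 final loss: {loss_phase2:.4f}")
```

Adding the decoder as a separate param group keeps the Adam state already built up for main, rather than starting over with a fresh optimizer.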