I have an autoencoder consisting of an encoder and a decoder, and I want to reconstruct images. I am optimizing the parameters of both the encoder and the decoder, but at the same time I need to reuse the encoder layers (usually you would use VGG) to compute the reconstruction loss. For this second encoder pass, I don't want gradients to flow into the encoder's parameters. However, torch.no_grad() discards the graph completely. Here is what the code looks like with torch.no_grad():
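```python
import torch
import torch.nn.functional as F

latent = encoder(x)
x2 = decoder(latent)  # decode the latent code, not the raw input
with torch.no_grad():
    latent2 = encoder(x2)  # second encoder pass: no graph is recorded
loss = F.mse_loss(x2, x) + F.mse_loss(latent2, latent)
```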
This way the decoder will not be optimized for the latent loss; only the MSE reconstruction loss will be used to optimize it.
The problem is not that this is impossible. It can be done by copying the encoder, but that is memory- and compute-inefficient because the copy has to be made in every iteration.
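A quick check of the torch.no_grad() version makes the problem concrete. This is a minimal sketch that assumes tiny hypothetical nn.Linear modules in place of the real encoder and decoder; it backpropagates only the latent term and shows that the decoder receives no gradient from it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
encoder = nn.Linear(8, 4)  # hypothetical stand-in for the real encoder
decoder = nn.Linear(4, 8)  # hypothetical stand-in for the real decoder
x = torch.randn(2, 8)

latent = encoder(x)
x2 = decoder(latent)
with torch.no_grad():
    latent2 = encoder(x2)

# latent2 was produced under no_grad, so it has no graph back to the decoder.
print(latent2.requires_grad)  # False

# Backpropagate only the latent term: it reaches the encoder through `latent`,
# but nothing connects it to the decoder.
latent_loss = F.mse_loss(latent2, latent)
latent_loss.backward()
print(decoder.weight.grad)  # None
```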
Let's suppose I make a copy of the encoder in each iteration. Here are the computation graphs:
x → encoder → latent
x → encoder → decoder → x2
x → encoder → decoder → encoder_copy → latent2
The optimizer has only the encoder and decoder parameters, not the encoder_copy parameters. So only the parts inside square brackets are updated:
x → [encoder] → latent
x → [encoder → decoder] → x2
x → [encoder → decoder] → encoder_copy → latent2
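The per-iteration copy described above can be sketched as follows. This is an illustration under assumptions: the toy nn.Linear modules, SGD, and the learning rate are hypothetical choices, not the original setup. copy.deepcopy allocates a fresh set of encoder weights on every iteration, which is exactly the overhead the question wants to avoid.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
encoder = nn.Linear(8, 4)  # hypothetical stand-in
decoder = nn.Linear(4, 8)  # hypothetical stand-in
optimizer = torch.optim.SGD(
    list(encoder.parameters()) + list(decoder.parameters()), lr=0.1
)
x = torch.randn(2, 8)

optimizer.zero_grad()
latent = encoder(x)
x2 = decoder(latent)

# Fresh frozen copy every iteration: its weights are not in the optimizer
# and do not require grad, but gradients still flow through it into x2.
encoder_copy = copy.deepcopy(encoder)
encoder_copy.requires_grad_(False)
latent2 = encoder_copy(x2)

loss = F.mse_loss(x2, x) + F.mse_loss(latent2, latent)
loss.backward()
optimizer.step()
```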
encoder_copy is not being optimized. But I don't want to copy the encoder in every iteration. If I reuse the encoder itself instead, I end up calculating gradients for the encoder in the second forward pass, which I don't want (it causes the model to diverge). Here is what I end up with:
x → [encoder] → latent
x → [encoder → decoder] → x2
x → [encoder → decoder → encoder] → latent2
But what I want is:
x → [encoder → decoder] → encoder → latent2
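One possible way to get this graph without copying weights is a sketch assuming PyTorch 2.0 or later, where torch.func.functional_call exists. The second encoder pass runs with detached views of the same parameter tensors: detach() shares storage instead of copying, and autograd treats those tensors as constants. Gradient therefore still flows through x2 into the decoder (and through the first encoder pass), but the second pass adds nothing to the encoder's own parameter gradients. The toy modules are hypothetical. If the real encoder contains BatchNorm, a second pass in train mode would still update its running statistics, which this sketch does not address.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call  # requires PyTorch >= 2.0

torch.manual_seed(0)
encoder = nn.Sequential(nn.Linear(8, 4), nn.ReLU())  # hypothetical stand-in
decoder = nn.Linear(4, 8)                            # hypothetical stand-in
x = torch.randn(2, 8)

latent = encoder(x)
x2 = decoder(latent)

# Detached views share memory with the real parameters (no copy), but
# autograd records no edge from them back to the encoder's leaves.
frozen = {name: p.detach() for name, p in encoder.named_parameters()}
latent2 = functional_call(encoder, frozen, (x2,))

# Graph: x -> [encoder -> decoder] -> encoder(frozen) -> latent2
loss = F.mse_loss(x2, x) + F.mse_loss(latent2, latent)
loss.backward()
```

The gradients this produces match the per-iteration deepcopy approach, without allocating a second set of weights.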