Hello, I think I have a problem that is simple in principle but seems to be fundamentally impossible to solve in PyTorch.
Let's assume we propagate a batch X twice through the same network E and then calculate a loss. I am interested in the gradient with respect to the parameters of E, but only through the first pass. In other words, the second pass through E together with the loss function could be considered the effective total loss function for the first pass of E.
This is also equivalent to the following situation:
We clone (deep copy) E and call it H. Then we propagate X once through E and then through H, that is, out = H(E(X)), and calculate a loss on its output. Finally, we calculate the gradient w.r.t. E only.
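For reference, the clone-based formulation can be sketched as follows. This is a minimal sketch that assumes a small toy MLP for E and a placeholder squared-output loss; it performs exactly the per-model copy the question wants to avoid, but it is a handy baseline for checking cheaper approaches against.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the network E; any nn.Module would do.
E = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 4))
X = torch.randn(8, 4)

# H is a deep copy of E with the same weights, but its parameters are
# excluded from autograd, so only E's parameters receive gradients.
H = copy.deepcopy(E)
H.requires_grad_(False)

out = H(E(X))
loss = out.pow(2).mean()  # placeholder loss, assumed for illustration
loss.backward()

# E's .grad fields now hold the gradient of the loss through the first pass only.
for name, p in E.named_parameters():
    print(name, p.grad.shape)
```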
But obviously I don’t want to clone a model for each batch, so how can I do this more efficiently?
Thanks for helping me out here!
PS: Here is a practical use case. Assume you have an autoencoder consisting of an encoder E: x->z and a decoder (or generator) G: z->x. Now suppose you want to impose not only a reconstruction loss, such that x_hat = G(E(x)) is close to x, but also a latent consistency loss, such that z_hat = E(x_hat) = E(G(E(x))) is close to z = E(x). As you can see, computing z_hat requires passing through the encoder E twice, but you don't want the encoder to learn to map x_hat close to z (that would result in a degenerate solution). What you want is to teach the encoder-decoder pair to reconstruct an x_hat such that, for a frozen encoder E, the latent variable z_hat = E(x_hat) is close to z = E(x). If I did not freeze the second pass of the encoder, that would be similar to the classic GAN training mistake of also updating the discriminator during the generator's training step.
The problem is that I cannot simply call zero_grad() or freeze the parameters of the encoder, because then the encoder does not get any gradients at all.
IIUC, just doing the second pass inside a torch.no_grad() context would suffice.
This does not work. If I do that, I don't get any gradients. I think you cannot backprop through a graph if part of it was computed inside torch.no_grad(), i.e. the network is blind to anything that happens after the no_grad() section. I think this fails for the same reason that you get no gradients if you compute your loss function inside torch.no_grad().
You could try chaining two backward passes together:
1. The first time you compute E, take its output y = E(x) and detach it with yd = y.detach().requires_grad_(), making sure the detached output requires grad.
2. When you use the detached output to compute z = E(yd) again, keep grad enabled so that the backward graph is still recorded.
3. Freeze the parameters of E, compute the final loss from z, then call gy = torch.autograd.grad(loss, yd)[0] to get the gradient of the final loss w.r.t. the detached y.
4. Finally, unfreeze E and call y.backward(gy) to backpropagate that gradient through the first pass.
edit: you don't even actually need to freeze/unfreeze the parameters of E, since the gradient in step 3 is taken w.r.t. yd only.
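The chained-backward recipe can be sketched like this. It is a minimal sketch, assuming a toy MLP for E and a placeholder squared-output loss; torch.autograd.grad returns the gradient w.r.t. yd without accumulating into E's .grad fields, which is why no freezing is needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy network standing in for E.
E = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 4))
x = torch.randn(8, 4)

# Step 1: first pass, y stays attached to the graph of E's parameters.
y = E(x)
# Detached copy that is a fresh leaf requiring grad.
yd = y.detach().requires_grad_()

# Step 2: second pass with grad mode on, so a backward graph is recorded.
z = E(yd)
loss = z.pow(2).mean()  # placeholder loss, assumed for illustration

# Step 3: gradient of the loss w.r.t. yd only; E's .grad is left untouched.
(gy,) = torch.autograd.grad(loss, yd)

# Step 4: push gy back through the first pass; this fills E's .grad.
y.backward(gy)
```

The result should match the deep-copy formulation out = H(E(X)), while reusing the same module for both passes.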
Wouldn’t something like this work:
Assuming L1 = Loss(x, x_hat)
And L2 = Loss(z, z_hat)
1. Compute L2 and backward
2. Zero encoder gradients
3. Compute L1 and backward
4. Update parameters
Probably a cleaner solution would be to use two different optimizers: one just for the decoder (L2) and one for the whole model (L1).
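The four steps above can be sketched as follows, assuming toy nn.Linear layers for E and G, MSE for both losses, and plain SGD (all illustrative choices, not from the post). retain_graph=True keeps the shared part of the graph alive so L1 can be backpropagated after L2.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy encoder E: x -> z and decoder G: z -> x.
E = nn.Linear(6, 2)
G = nn.Linear(2, 6)
opt = torch.optim.SGD(list(E.parameters()) + list(G.parameters()), lr=0.1)

x = torch.randn(16, 6)
opt.zero_grad()

z = E(x)
x_hat = G(z)
z_hat = E(x_hat)

# 1. Compute L2 = Loss(z, z_hat) and backward; keep the graph for L1.
L2 = F.mse_loss(z_hat, z)
L2.backward(retain_graph=True)

# 2. Zero the encoder gradients; G keeps its gradient from L2.
E.zero_grad()

# 3. Compute L1 = Loss(x, x_hat) and backward.
L1 = F.mse_loss(x_hat, x)
L1.backward()

# 4. Update the parameters.
opt.step()
```

Note that with this ordering the encoder receives no gradient from L2 at all, not even through its first pass; only G is trained by L2.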
But your z_hat = E(x_hat) is the final segment, so why is this a problem? To apply a loss like distance(z_hat, z), z_hat gradient is not needed.
If I compute z_hat = E(x_hat) inside torch.no_grad() and then call distance(z_hat, z.detach()).backward(), no gradients are produced. What I meant is that if you further process z_hat outside torch.no_grad(), then of course you can get gradients again for those later modules. In that sense, applying torch.no_grad() at some position is like cutting the gradient flow at that position in the graph.
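This cutting behaviour can be checked with a small sketch, assuming toy nn.Linear modules; D is a hypothetical module standing for any processing applied after the no_grad section.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy modules: E runs under no_grad, D runs afterwards.
E = nn.Linear(3, 3)
D = nn.Linear(3, 1)
x_hat = torch.randn(5, 3, requires_grad=True)
z = torch.randn(5, 3)

with torch.no_grad():
    z_hat = E(x_hat)  # no graph is recorded for this pass

# A loss built only from z_hat and a detached target has no grad_fn,
# so calling .backward() on it would raise a RuntimeError.
loss = (z_hat - z.detach()).pow(2).mean()
print(loss.requires_grad)  # False

# Processing z_hat further outside no_grad records a graph again, but only
# for the new operations: D gets gradients, E and x_hat do not.
out = D(z_hat).sum()
out.backward()
print(D.weight.grad is not None, E.weight.grad is None, x_hat.grad is None)  # True True True
```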
Hm, even if you compute dDistanceLoss(E(x_hat), z.detach()) / dx_hat for your second loss, that may conflict with the reconstruction loss. Using a non-detached z for the second loss seems more promising to me… Anyway, the manipulations suggested above may do what you want, or you can temporarily disable .requires_grad for all of E's nn.Parameters.
Temporarily disabling .requires_grad works! I have not found a better solution so far. I am detaching z in distance(z_hat, z.detach()) because I don't want the encoder to adapt explicitly to the "bad" encoding z_hat.
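A minimal sketch of that approach for the autoencoder use case, assuming toy nn.Linear layers for E and G, MSE for both losses, and Adam. set_requires_grad is a hypothetical helper written here, not a PyTorch function. While E is frozen, gradients still flow through the second pass into x_hat (and so into G and the first pass of E), but E's weights get nothing from the second pass itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy autoencoder: encoder E: x -> z, decoder G: z -> x.
E = nn.Linear(8, 3)
G = nn.Linear(3, 8)
opt = torch.optim.Adam(list(E.parameters()) + list(G.parameters()), lr=1e-3)


def set_requires_grad(module: nn.Module, flag: bool) -> None:
    """Hypothetical helper: toggle requires_grad on every parameter of `module`."""
    for p in module.parameters():
        p.requires_grad_(flag)


x = torch.randn(32, 8)
opt.zero_grad()

z = E(x)                    # first pass: E is trainable here
x_hat = G(z)
L1 = F.mse_loss(x_hat, x)   # reconstruction loss

# Second pass with E frozen: no gradient for E's weights from this pass.
set_requires_grad(E, False)
z_hat = E(x_hat)
set_requires_grad(E, True)

# Latent consistency loss with a detached target, as in the post.
L2 = F.mse_loss(z_hat, z.detach())

(L1 + L2).backward()
opt.step()
```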