Propagation through two identical networks, but do not accumulate gradients w.r.t. the second pass

Hello, I have what is in principle a relatively simple problem, but it seems fundamentally impossible to solve in PyTorch.
Let’s assume we propagate a batch X two times through the same network E and then calculate a loss. I am interested in the gradient with respect to the parameters of E, but only through the first pass. In other words, the second pass through E together with the loss function can be considered the effective total loss function for the first pass through E.
This is also equivalent to the following situation:
We clone (deep copy) E and call it H. Then we propagate X once through E and then through H, that is, out = H(E(X)), and calculate a loss on the output. Finally, we compute the gradient w.r.t. E only.
But obviously I don’t want to clone a model for each batch, so how can I do this more efficiently?

Thanks for helping me out here!

PS: Here is a use case in practice. Assume you have an autoencoder consisting of an encoder E: x -> z and a decoder (or generator) G: z -> x. Now suppose you not only want to impose a reconstruction loss s.t. x ≈ G(E(x)) = x_hat, but also a latent consistency loss s.t. z = E(x) ≈ E(G(E(x))) = E(x_hat) = z_hat. As you can see, in order to calculate z_hat you propagate two times through the encoder E, but you don’t want the encoder to learn to map x_hat close to z (that would result in a degenerate solution). What you want is to teach the encoder-decoder pair to reconstruct an x_hat such that, for a frozen encoder E, the latent variable z_hat = E(x_hat) is close to z = E(x). If I did not freeze the second pass of the encoder, that would be similar to the classic mistake in GAN training of updating the discriminator during the generator training step.
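For concreteness, the setup can be sketched as follows (hypothetical linear encoder/decoder with illustrative shapes; this is the naive version, in which the second encoder pass would also receive gradients):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the encoder E: x -> z and decoder G: z -> x.
E = nn.Linear(6, 2)  # encoder
G = nn.Linear(2, 6)  # decoder

x = torch.randn(5, 6)
z = E(x)             # first encoder pass
x_hat = G(z)         # reconstruction
z_hat = E(x_hat)     # second encoder pass -- the one we want to "freeze"

recon_loss = (x - x_hat).pow(2).mean()
latent_loss = (z - z_hat).pow(2).mean()
```

Backpropagating `latent_loss` as written would update E through both passes, which is exactly the degenerate behavior described above.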
The problem is that I cannot simply zero_grad() or freeze the parameters of the encoder, because then the encoder would not get any gradients at all.

IIUC, just doing the second pass inside a torch.no_grad() context would suffice.


This does not work. If I do that, I don’t get any gradients. You cannot backprop through a graph if part of it was calculated inside torch.no_grad(); i.e., the network is blind to anything that happens after the no_grad() section. I think this fails for the same reason that you get no gradients if you compute your loss function inside torch.no_grad().
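A minimal illustration of why the no_grad() approach yields no gradients at all (toy tensors, not the actual model):

```python
import torch

w = torch.randn(3, requires_grad=True)
y = (w * 2).sum()          # tracked computation (the "first pass")
with torch.no_grad():
    loss = y * 3           # the graph is cut here: loss has no grad_fn
print(loss.requires_grad)  # False -- loss.backward() raises a RuntimeError
```

Anything downstream of the no_grad() block is disconnected from the graph, so there is no path for gradients to reach `w`.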

You could try chaining two backwards together:
The first time you compute `E`, take its output `y = E(x)` and detach it with `yd = y.detach().requires_grad_()`, making sure the detached output requires grad.
When you use the detached output to compute `z = E(yd)` again, keep grad enabled so the backward graph of the second pass is still built.
Freeze the parameters of E, then call `gy = torch.autograd.grad(loss, yd)[0]` to get the gradient of the final loss w.r.t. the detached y.
Finally, you can unfreeze E and call `y.backward(gy)` to send that gradient back through the first pass only.

edit: you don’t actually need to freeze/unfreeze the parameters of E, since `torch.autograd.grad` only computes gradients for the inputs you pass it and does not accumulate anything into the parameters’ `.grad`.
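A runnable sketch of this recipe, with a hypothetical one-layer encoder standing in for E:

```python
import torch
import torch.nn as nn

E = nn.Linear(4, 4)        # hypothetical encoder
x = torch.randn(8, 4)

y = E(x)                            # first pass (will receive gradients)
yd = y.detach().requires_grad_()    # cut the graph, but track yd itself
z = E(yd)                           # second pass
loss = z.pow(2).mean()              # any final loss on z

# Gradient of the loss w.r.t. the detached activation only; E's parameters
# get nothing here, because autograd.grad targets just the inputs we pass.
gy = torch.autograd.grad(loss, yd)[0]

# Backprop gy through the first pass: E's parameters now get gradients
# exactly as if the second pass had gone through a frozen copy of E.
y.backward(gy)
```

The resulting `E.weight.grad` should match (up to floating point) what the deep-copy formulation out = H(E(x)) with a frozen clone H would produce.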

Wouldn’t something like this work:

Assuming L1 = Loss(x, x_hat)
And L2 = Loss(z, z_hat)

1. Compute L2 and backward
2. Zero encoder gradients
3. Compute L1 and backward
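A sketch of these three steps, reusing a hypothetical encoder/decoder pair as in the question. Note a caveat: step 2 also throws away the encoder gradient that L2 produced through the *first* pass, so the encoder ends up trained by L1 only, which may differ from the freeze-the-second-pass behavior asked for above.

```python
import torch
import torch.nn as nn

E = nn.Linear(6, 2)  # hypothetical encoder
G = nn.Linear(2, 6)  # hypothetical decoder
x = torch.randn(5, 6)

z = E(x)
x_hat = G(z)
z_hat = E(x_hat)

L2 = (z - z_hat).pow(2).mean()   # latent consistency loss
L2.backward(retain_graph=True)   # 1. grads for G and for BOTH passes of E
E.zero_grad()                    # 2. drop all encoder grads from L2
L1 = (x - x_hat).pow(2).mean()   # reconstruction loss
L1.backward()                    # 3. accumulate reconstruction grads only
```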