Detach and re-attach variable

I am trying to accomplish something like the following:

z = encoder1(x, y)
theta = encoder2(x, y)
predMean = decoder(x, z, theta)

where x and y are my data.

In the loss function, I would like to introduce a term that is the derivative of predMean wrt x, considering z and theta constant.

However I think that

predMeanGrad = torch.autograd.grad(
    outputs=predMean,
    inputs=x,
    grad_outputs=torch.ones_like(predMean),  # same shape, dtype and device as predMean
    create_graph=True,
    retain_graph=True)[0]

is actually applying the chain rule, so that the derivative also contains the terms that go through dz/dx and dtheta/dx.
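A small check of this suspicion, as a sketch only: the scalar "encoder" and "decoder" below are hypothetical stand-ins, not the real models. It compares the gradient autograd returns when z is still attached with the one obtained from a detached copy of z.

```python
import torch

# Hypothetical scalar stand-ins for the encoder and decoder.
x = torch.tensor(2.0, requires_grad=True)
z = 3.0 * x      # "encoder": dz/dx = 3
pred = x * z     # "decoder": with z held fixed, d(pred)/dx = z

# Total derivative: autograd also follows the path through z.
(total,) = torch.autograd.grad(pred, x, retain_graph=True)

# Partial derivative: the decoder sees a detached copy of z.
pred_partial = x * z.detach()
(partial,) = torch.autograd.grad(pred_partial, x)

print(total.item(), partial.item())
```

Here pred equals 3x^2, so the attached version gives 6x, while the detached version gives only z = 3x.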

My initial guess was to detach z and theta and compute the derivative of predMean; however I would need to reattach them when back-propagating the loss.

Are there more sophisticated solutions to your knowledge?

You can’t reattach, but you can

  • z_d = z.detach().requires_grad_() and similarly theta_d.
  • pred_mean = decoder(x, z_d, theta_d)
  • compute x_grad, z_d_grad, theta_d_grad as grad of pred_mean w.r.t. x, z_d, theta_d.
  • theta.backward(theta_d_grad), analogously for z.

You end up with x.grad and y.grad, the gradients w.r.t. x and y with the decoder's direct input x held fixed (so flowing only through theta and z), and with x_grad, the gradient w.r.t. x with z and theta held fixed.
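The steps above could be sketched as follows. The nn.Linear modules, tensor shapes and the concatenation of inputs are assumptions made only for illustration; they stand in for the real encoder1, encoder2 and decoder.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for encoder1, encoder2 and decoder.
enc1 = nn.Linear(2, 1)   # z = encoder1(x, y)
enc2 = nn.Linear(2, 1)   # theta = encoder2(x, y)
dec = nn.Linear(3, 1)    # pred_mean = decoder(x, z, theta)

x = torch.randn(4, 1, requires_grad=True)
y = torch.randn(4, 1, requires_grad=True)

xy = torch.cat([x, y], dim=1)
z = enc1(xy)
theta = enc2(xy)

# Step 1: cut the graph at z and theta, keeping the copies as leaves
# that can receive gradients.
z_d = z.detach().requires_grad_()
theta_d = theta.detach().requires_grad_()

# Step 2: run the decoder on the detached copies.
pred_mean = dec(torch.cat([x, z_d, theta_d], dim=1))

# Step 3: gradients of pred_mean w.r.t. x (z and theta held fixed),
# and w.r.t. the detached copies z_d and theta_d.
x_grad, z_d_grad, theta_d_grad = torch.autograd.grad(
    outputs=pred_mean,
    inputs=(x, z_d, theta_d),
    grad_outputs=torch.ones_like(pred_mean),
)

# Step 4: push the gradients at the cut back through the encoders.
# z and theta share the torch.cat node, so the first call keeps the graph.
z.backward(z_d_grad, retain_graph=True)
theta.backward(theta_d_grad)

# x.grad and y.grad now hold only the contributions through z and theta.
print(x_grad.shape, x.grad.shape, y.grad.shape)
```

If x_grad is to enter the loss, as in the original question, step 3 would also need create_graph=True so that x_grad itself stays differentiable.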

Best regards

Thomas

Hi Thomas,

thank you so much!!