Computing the gradient of a function f2 inside the backward() call of a different function f1

I would like to implement Eq. 6 in https://arxiv.org/pdf/1805.08498.pdf:

d/dphi z = - (d/dz S_phi(z))^{-1} d/dphi S_phi(z),

where z = S_phi^{-1}(eps) for some eps. In essence, letting f1 = S_phi^{-1} and f2 = S_phi, I’d like to compute

d/dphi f2

inside the backward() call of a torch.autograd.Function that implements f1. My current strategy is to create a new computation graph by cloning phi, i.e.

phi2 = phi.detach().clone().requires_grad_()  # fresh leaf, detached from the outer graph
F = f2(phi2, z)  # recompute S_phi(z) on the new graph
F.backward(grad_output)  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

inside the backward() call of f1. However, it appears that the whole backward() call runs under torch.no_grad(), so F ends up without a grad_fn and F.backward() raises the RuntimeError shown above. How do I compute gradients inside backward() calls?
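(As a quick sanity check of Eq. 6, using the Gaussian case purely as an example: with S_phi(z) = Phi((z - mu) / sigma), both d/dz S_phi(z) and -d/dmu S_phi(z) equal the Gaussian density N(z | mu, sigma), so Eq. 6 gives d/dmu z = 1, which is exactly what the explicit reparameterization z = mu + sigma * eps gives.)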

Edit: Added a sentence to the post for clarity

You can use this API to re-enable grad: enable_grad — PyTorch 1.11.0 documentation
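For example, here is a minimal sketch of how that can look inside backward(). The class name InverseCDF and the Gaussian choice of S_phi (so the inverse has a closed form) are only for illustration, and eps, mu and sigma are assumed to have the same shape so that all derivatives are elementwise:

import torch

class InverseCDF(torch.autograd.Function):
    # Sketch: forward computes z = S_phi^{-1}(eps); backward applies Eq. 6.
    # Assumes eps, mu and sigma share the same shape, so all derivatives are elementwise.

    @staticmethod
    def forward(ctx, eps, mu, sigma):
        # Closed-form inverse of the Gaussian CDF; in general this would be a numerical solve.
        z = mu + sigma * (2 ** 0.5) * torch.erfinv(2 * eps - 1)
        ctx.save_for_backward(z, mu, sigma)
        return z

    @staticmethod
    def backward(ctx, grad_output):
        z, mu, sigma = ctx.saved_tensors
        with torch.enable_grad():  # grad mode is off inside backward(); switch it back on
            mu2 = mu.detach().clone().requires_grad_()
            sigma2 = sigma.detach().clone().requires_grad_()
            z2 = z.detach().clone().requires_grad_()
            # Rebuild F = S_phi(z) on a fresh graph
            F = 0.5 * (1 + torch.erf((z2 - mu2) / (sigma2 * 2 ** 0.5)))
            dS_dz, dS_dmu, dS_dsigma = torch.autograd.grad(F.sum(), (z2, mu2, sigma2))
        # Eq. 6, elementwise: d/dphi z = -(d/dz S_phi(z))^{-1} d/dphi S_phi(z)
        grad_mu = grad_output * (-dS_dmu / dS_dz)
        grad_sigma = grad_output * (-dS_dsigma / dS_dz)
        return None, grad_mu, grad_sigma  # no gradient for eps

mu = torch.tensor([0.5], requires_grad=True)
sigma = torch.tensor([1.2], requires_grad=True)
eps = torch.rand(1)
z = InverseCDF.apply(eps, mu, sigma)
z.backward()
print(mu.grad)     # tensor([1.]) -- matches dz/dmu = 1 from z = mu + sigma * eps_std
print(sigma.grad)  # equals (z - mu) / sigma

Using torch.autograd.grad on the fresh graph avoids calling .backward() inside the inner block and accumulating into the clones' .grad fields.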

Yes! This was exactly what I was looking for. Thank you very much!
