I would like to implement Eq. 6 in https://arxiv.org/pdf/1805.08498.pdf:

d/dphi z = - (d/dz S_phi(z))^{-1} d/dphi S_phi(z),

where z = S_phi^{-1}(eps) for some eps. In essence, letting `f1 = S_phi^{-1}` and `f2 = S_phi`, I'd like to compute d/dphi `f2` inside the `backward()` call of a `torch.autograd.Function` that implements `f1`. My current strategy is to create a new computation graph by cloning phi, i.e.

```
# fresh leaf copy of phi, so f2's graph is rebuilt from scratch
phi2 = phi.detach().clone().requires_grad_()
F = f2(phi2, z)
F.backward(grad_output)  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

inside of the `backward()` call of `f1`. However, it appears that all of `backward()` is wrapped in a `with torch.no_grad()` block, so I get the above error when I call `F.backward()`. How do I compute gradients inside `backward()` calls?
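
For concreteness, here is a minimal self-contained repro of the problem. The scale transform S_phi(z) = phi * z (and all the names below) is just a toy stand-in I picked so the snippet runs end to end; the real S_phi would be a distribution's standardization function as in the paper.

```
import torch

# Toy stand-in for S_phi: S_phi(z) = phi * z, so S_phi^{-1}(eps) = eps / phi.
def f2(phi, z):
    return phi * z

class F1(torch.autograd.Function):
    """Implements f1 = S_phi^{-1}."""

    @staticmethod
    def forward(ctx, phi, eps):
        z = eps / phi
        ctx.save_for_backward(phi, z)
        return z

    @staticmethod
    def backward(ctx, grad_output):
        phi, z = ctx.saved_tensors
        # Strategy from the post: clone phi into a fresh leaf and rebuild
        # f2's graph to obtain d/dphi S_phi(z).
        phi2 = phi.detach().clone().requires_grad_()
        F = f2(phi2, z)
        # backward() runs with grad mode disabled, so F has no grad_fn and
        # the next line raises the RuntimeError quoted above.
        F.backward(grad_output)
        # Intended (never reached): Eq. 6, with d/dz S_phi(z) = phi here.
        return -phi2.grad / phi, None

phi = torch.tensor(2.0, requires_grad=True)
eps = torch.tensor(3.0)
z = F1.apply(phi, eps)
try:
    z.backward()
except RuntimeError as e:
    print(e)  # element 0 of tensors does not require grad and does not have a grad_fn
```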
