@tom, I think that is really a nice approach. I’d like to explain it a bit:
In PyTorch, autograd usually computes the gradients of all operations automatically, as long as `requires_grad` is set to `True`. If, however, you need operations that are not natively supported by PyTorch's autograd, you can define the function and how to compute its gradients manually. For that reason, autograd is turned off by default in `forward` and `backward` of subclasses of `torch.autograd.Function`. To turn it back on manually, tom used `with torch.enable_grad():`.
Within this block, gradients are 1) automatically calculated and 2) passed backwards through the graph. You want 1), but I think 2) is not a good idea within `forward`, because you are expected to pass the gradient backwards explicitly in `backward`.
To prevent the gradient from flowing backward automatically, you need to detach the input from the graph (`a_.detach()`). I guess from then on you could let the gradient be calculated implicitly by returning `res` (instead of `res.detach()`) and get the gradient via `gr = a.grad` (perhaps you would have to set `retain_graph=True` before calculating `res`), but the more explicit way is to also detach the result and calculate the gradient explicitly with
`gr, = torch.autograd.grad(res, a, grad_out, retain_graph=True)`
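Putting the pieces together, a minimal sketch of the whole pattern could look like this (the class name and the computation inside `forward` are made up for illustration, and I store the tensors directly on `ctx` for brevity):

```python
import torch

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a_):
        # Detach the input so the gradient does not flow back automatically,
        # then turn gradient tracking back on for the detached copy.
        a = a_.detach().requires_grad_()
        with torch.enable_grad():
            res = (a * 2).sin()   # placeholder for the actual computation
        # Keep the inner graph (a -> res) around for backward.
        ctx.a = a
        ctx.res = res
        return res.detach()

    @staticmethod
    def backward(ctx, grad_out):
        # Backpropagate grad_out through the graph built in forward.
        gr, = torch.autograd.grad(ctx.res, ctx.a, grad_out, retain_graph=True)
        return gr

x = torch.randn(4, requires_grad=True)
y = MyFunc.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, 2 * (x * 2).cos()))  # matches d/dx sin(2x)
```

(In this sketch the inner graph is only backpropagated once per forward, so `retain_graph=True` should not be strictly necessary; I kept it to match the line above.)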
Concerning `retain_graph` I am a bit puzzled, too.
Apart from that, I think this is a good approach. I cannot think of big impacts on other aspects, so this should be safe. The only thing I can think of is that it forces the calculation of gradients (within this function) even when running the defined function in a `with torch.no_grad():` block. If the computation of the gradients of your function is complex or requires storing lots of intermediate gradients, that might cause GPU memory or runtime issues.
One thing you could do is to check whether the input (`a_`) requires gradients and use `torch.no_grad()` instead of `torch.enable_grad()`, or skip the `requires_grad_()` part, to prevent the calculation of some unnecessary gradients.
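A sketch of that idea, reusing the hypothetical `MyFunc` from above (I am assuming here that the `requires_grad` flag of `a_` is still visible inside `forward`; otherwise the check would have to happen before calling the function):

```python
import torch

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a_):
        a = a_.detach()
        if a_.requires_grad:
            # Gradients will be needed: build the inner graph as before.
            a.requires_grad_()
            with torch.enable_grad():
                res = (a * 2).sin()   # placeholder computation
            ctx.a, ctx.res = a, res
            return res.detach()
        # No gradients needed: skip building the inner graph entirely.
        ctx.a = ctx.res = None
        with torch.no_grad():
            return (a * 2).sin()

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.res is None:
            return None               # no gradient was requested for a_
        gr, = torch.autograd.grad(ctx.res, ctx.a, grad_out, retain_graph=True)
        return gr
```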