@tom, I think that is really a nice approach. I’d like to explain it a bit:
In PyTorch, autograd usually computes the gradients of all operations automatically, as long as requires_grad is set to True. If you need operations that are not natively supported by PyTorch's autograd, however, you can define the function and how to compute its gradients manually. For that reason, autograd is turned off by default inside the forward and backward of torch.autograd.Function subclasses. To turn it back on manually, tom used a with torch.enable_grad(): block inside forward.
Within this block gradients are 1) automatically calculated and 2) passed backwards through the graph. You want 1), but I think 2) is not a good idea within forward, because you are expected to pass the gradient backwards explicitly in backward. To prevent the gradient from automatically flowing backward, you need to detach the input from the graph (a_.detach()). I guess from then on you could let the gradient be calculated implicitly by returning res (instead of res.detach()) and get the gradient via gr = a.grad (perhaps you would have to set retain_graph=True before calculating res), but the more explicit way is to also detach the result and calculate the gradient explicitly with

gr, = torch.autograd.grad(res, a, grad_out, retain_graph=True)
Regarding retain_graph, I am a bit puzzled, too.
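To make the whole pattern concrete, here is a minimal sketch of how I read it (my own reconstruction, not code from this thread; a_.sin().exp() is just a stand-in for the actual computation, and I take the gradient with respect to the detached copy a_):

```python
import torch

class MyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        # work on a detached copy so the gradient does not flow back automatically
        a_ = a.detach().requires_grad_()
        with torch.enable_grad():          # re-enable autograd inside forward
            res = a_.sin().exp()           # stand-in for the actual computation
        ctx.save_for_backward(a_, res)     # keep the sub-graph around for backward
        return res.detach()                # return a detached result

    @staticmethod
    def backward(ctx, grad_out):
        a_, res = ctx.saved_tensors
        # explicitly push grad_out backwards through the sub-graph built in forward
        gr, = torch.autograd.grad(res, a_, grad_out, retain_graph=True)
        return gr


x = torch.randn(3, requires_grad=True)
y = MyFunc.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, torch.cos(x) * torch.exp(torch.sin(x))))  # True
```

With retain_graph=True the sub-graph built in forward is kept alive, so backward could be called more than once; whether that is strictly necessary here is exactly the part I am unsure about as well.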
Apart from that, I think this is a good approach. I cannot think of big impacts on other aspects, so this should be safe. The only thing I can think of is that it forces the calculation of gradients (within this function) even when running the defined function in a torch.no_grad() block. If the computation of your function's gradients is complex or requires storing lots of intermediate results, that might cause GPU memory or runtime issues.
One thing you could do is to check whether the input (a_) requires gradients and, if it does not, use torch.no_grad() instead of torch.enable_grad() or skip the requires_grad_() part, to prevent the calculation of some unnecessary gradients.
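As a sketch of that idea (again my own illustration; it only keys off the requires_grad flag of the input, so calling the function under torch.no_grad() with an input that still has requires_grad=True would not benefit):

```python
import torch

class MyFuncConditional(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        needs_grad = a.requires_grad            # skip graph construction otherwise
        a_ = a.detach()
        if needs_grad:
            a_.requires_grad_()                 # only track gradients when needed
        with (torch.enable_grad() if needs_grad else torch.no_grad()):
            res = a_.sin().exp()                # stand-in for the actual computation
        if needs_grad:
            ctx.save_for_backward(a_, res)      # only store the sub-graph when needed
        return res.detach()

    @staticmethod
    def backward(ctx, grad_out):
        a_, res = ctx.saved_tensors
        gr, = torch.autograd.grad(res, a_, grad_out, retain_graph=True)
        return gr


x = torch.randn(3)                              # no requires_grad -> no inner graph
y = MyFuncConditional.apply(x)
```

That should avoid building and storing the intermediate graph during pure inference, which is where I would expect the memory overhead to show up.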