Hi,
This is a silly example of how to trigger reentrant autograd, that is, calling into the autograd engine (here, running a backward) from inside the backward of a custom Function.
For example, suppose the gradient you want to return for your function is the gradient of another function that you compute with autograd itself:
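```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class MyFunc(Function):
    @staticmethod
    def forward(ctx, inp):
        # forward runs with grad disabled, so no graph is recorded for
        # (2*inp).sum(); we only keep the input to recompute it later
        ctx.save_for_backward(inp)
        return (2*inp).sum()

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        # Let's compute the backward by using autograd
        tmp_inp = inp.detach().requires_grad_()
        with torch.enable_grad():
            tmp_out = (2*tmp_inp).sum()
        # Reentrant call: runs the autograd engine from inside this backward
        tmp_out.backward(grad_output)
        return tmp_inp.grad

a = torch.rand(10, 10, requires_grad=True)
b = MyFunc.apply(a)
b.backward()
print(a.grad)  # all entries are 2
```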
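A variant of the same backward can use torch.autograd.grad, which also re-enters the autograd engine but returns the gradient directly instead of accumulating it into tmp_inp.grad. This is a minimal sketch for the same toy function; the class name MyFuncGrad is just illustrative.

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class MyFuncGrad(Function):
    @staticmethod
    def forward(ctx, inp):
        # Only the input is saved; no graph is recorded inside forward.
        ctx.save_for_backward(inp)
        return (2 * inp).sum()

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        tmp_inp = inp.detach().requires_grad_()
        with torch.enable_grad():
            tmp_out = (2 * tmp_inp).sum()
        # Reentrant call into the engine; returns a tuple with one gradient per input.
        (grad_inp,) = torch.autograd.grad(tmp_out, tmp_inp, grad_outputs=grad_output)
        return grad_inp


a = torch.rand(10, 10, requires_grad=True)
MyFuncGrad.apply(a).backward()
print(torch.allclose(a.grad, torch.full_like(a, 2.0)))  # True
```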
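The only advantage of this silly implementation is memory: Function.forward runs with gradient tracking disabled, so only the input is saved and no intermediate results of (2*inp).sum() are stored for the backward; anything needed is recomputed inside the backward instead. The checkpointing feature (torch.utils.checkpoint) uses a similar trick.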
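The same recompute-in-backward idea generalizes to an arbitrary function, which is roughly what checkpointing does. Below is a minimal sketch assuming a single tensor input and a single tensor output; the names Recompute and fn are hypothetical, and the real torch.utils.checkpoint handles much more (RNG state, multiple inputs, non-tensor arguments).

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class Recompute(Function):
    """Hypothetical checkpoint-style wrapper: run fn without keeping its
    intermediates, then rerun it under autograd during backward."""

    @staticmethod
    def forward(ctx, fn, inp):
        ctx.fn = fn
        ctx.save_for_backward(inp)
        # Grad is disabled inside forward, so fn's intermediates are not kept.
        return fn(inp)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        tmp_inp = inp.detach().requires_grad_()
        with torch.enable_grad():
            tmp_out = ctx.fn(tmp_inp)  # recompute the forward with a graph
        # Reentrant backward through the recomputed graph.
        tmp_out.backward(grad_output)
        # No gradient for fn (a Python callable), then the gradient for inp.
        return None, tmp_inp.grad


def fn(x):
    return torch.tanh(x).exp().sum()


x = torch.rand(5, requires_grad=True)
Recompute.apply(fn, x).backward()

# Reference gradient from ordinary autograd on the same function.
x_ref = x.detach().requires_grad_()
fn(x_ref).backward()
print(torch.allclose(x.grad, x_ref.grad))  # True
```

In practice you would call torch.utils.checkpoint.checkpoint instead; its use_reentrant=True mode is built on this same pattern of running a backward from inside a custom Function's backward.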