What is the scenario of 'Reentrant backwards' in the PyTorch source code?

Hi,

Here is a silly example of how to trigger a reentrant backward. It comes up, for example, when the gradient you want to return from your custom Function's backward is itself the gradient of another function that you compute with autograd (i.e. backward calls backward):

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class MyFunc(Function):
    @staticmethod
    def forward(ctx, inp):
        # Only the input is saved; no intermediate results are kept.
        ctx.save_for_backward(inp)
        return (2*inp).sum()

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        # Recompute the forward and let autograd compute the gradient.
        # The nested call to .backward() below is what makes this
        # backward reentrant.
        tmp_inp = inp.detach().requires_grad_()
        with torch.enable_grad():
            tmp_out = (2*tmp_inp).sum()
        tmp_out.backward()
        # Chain rule: gradient of the recomputed forward scaled by the
        # incoming gradient.
        return tmp_inp.grad * grad_out


a = torch.rand(10, 10, requires_grad=True)

b = MyFunc.apply(a)
b.backward()
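
As a quick sanity check (assuming the corrected backward above): the gradient of (2*a).sum() with respect to a is a tensor filled with 2, so a.grad should match that.

# Sanity check: the gradient of (2*a).sum() w.r.t. a is a tensor of twos.
assert torch.allclose(a.grad, 2 * torch.ones_like(a))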

The only advantage of this silly implementation is that no intermediate results are saved for the (2*inp).sum() computation (only the input is saved), so it uses a bit less memory. The checkpointing feature uses a similar trick.
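
For comparison, here is a minimal sketch using torch.utils.checkpoint.checkpoint (the block function below is just an illustrative stand-in): the intermediate activations of block are not stored during forward and are recomputed inside the backward pass, and the reentrant implementation of checkpointing again calls backward from within backward. The use_reentrant argument exists on recent PyTorch versions; older versions take no such argument and only have the reentrant behaviour.

import torch
from torch.utils.checkpoint import checkpoint

# Illustrative function whose intermediate activations we don't want to keep.
def block(x):
    return torch.relu(x @ x.t()).sum()

x = torch.rand(10, 10, requires_grad=True)
# checkpoint() recomputes block(x) during the backward pass instead of
# storing its intermediates; the reentrant implementation runs a nested
# backward, just like MyFunc above.
out = checkpoint(block, x, use_reentrant=True)
out.backward()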
