What is the scenario for 'reentrant backwards' in the PyTorch source code?

Hi, I have recently been reading the PyTorch C++ source code. In the file engine.cpp there is the C++ engine that runs the backward pass.
The author explains reentrant backwards using the scenario "suppose that you call backward() inside of a worker thread."
But I cannot imagine such a scenario under any condition.

Usually, the code below performs the backward gradient computation. The call stack looks like this: y.backward() -> THPEngine_run_backward -> Engine::execute() -> thread_main(). The worker thread does the actual computation.

The author says that if, in the worker thread during the backward pass, we call backward() again, that call is nested.

So in this case, what should my Python code look like? I don't know how to trigger such a scenario at the Python level. I don't even know how to call backward() inside a worker thread.

Maybe reentrant backwards is designed for extensibility?

import torch

x = torch.tensor([10.0], requires_grad=True)  # Variable is deprecated; tensors track grads directly
y = x * 5
y.backward()  # x.grad is now tensor([5.])
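A minimal sketch of one way to reach the nested case from Python, building on the snippet above: a hook registered on y with register_hook is called by the engine while the backward pass is running, so calling backward() inside that hook starts a nested (reentrant) backward. The hook name nested_hook and the extra tensor aux are hypothetical, chosen only for illustration.

```python
import torch

# Hypothetical extra leaf whose gradient is computed by the nested pass
aux = torch.tensor(3.0, requires_grad=True)

def nested_hook(grad):
    # Called by the autograd engine while y's gradient is being processed.
    # Grad mode is off during a plain backward, so re-enable it to build a graph.
    with torch.enable_grad():
        out = aux * aux
    # Calling backward() from inside a running backward pass is the nested case
    out.backward()
    return grad  # leave y's gradient unchanged

x = torch.tensor([10.0], requires_grad=True)
y = x * 5
y.register_hook(nested_hook)
y.backward()

print(x.grad)    # tensor([5.])
print(aux.grad)  # tensor(6.)
```

Both gradients get filled in: the outer pass produces x.grad and the nested pass produces aux.grad (d(aux*aux)/d(aux) = 2*3 = 6).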


Hi,

This is a silly example of how to trigger reentrant autograd.
Suppose the gradient you want to return for your function is the gradient of another function, which you compute with autograd inside backward():

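import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class MyFunc(Function):
    @staticmethod
    def forward(ctx, inp):
        # Save only the input; no intermediate results of 2*inp are kept
        ctx.save_for_backward(inp)
        return (2*inp).sum()

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        # Let's compute the backward by using autograd
        inp, = ctx.saved_tensors
        tmp_inp = inp.detach().requires_grad_()
        # once_differentiable runs backward under no_grad, so re-enable it
        with torch.enable_grad():
            tmp_out = (2*tmp_inp).sum()
        # This backward() call runs inside the outer backward: it is reentrant
        tmp_out.backward(grad_out)
        # Return the gradient w.r.t. the input (same shape as inp)
        return tmp_inp.grad


a = torch.rand(10, 10, requires_grad=True)

b = MyFunc.apply(a)
b.backward()  # a.grad is a 10x10 tensor filled with 2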

The only advantage of this silly implementation is that no intermediate results of the (2*inp).sum() computation are saved during the forward pass, which saves a bit of memory. The checkpointing feature (torch.utils.checkpoint) uses a similar trick.
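For comparison, a small sketch of the checkpointing trick, assuming a PyTorch version whose torch.utils.checkpoint.checkpoint accepts the use_reentrant argument. With use_reentrant=True, the forward runs the wrapped function without keeping its intermediates, and the backward reruns it and calls autograd again from inside backward, which is the same reentrant pattern as MyFunc. The function name block is hypothetical.

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # Intermediates of this function are not stored during the forward pass
    return torch.sin(x) * 2

x = torch.rand(4, requires_grad=True)
# The reentrant variant recomputes block(x) and runs a nested backward on it
y = checkpoint(block, x, use_reentrant=True)
y.sum().backward()

print(torch.allclose(x.grad, 2 * torch.cos(x)))  # True
```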
