Is there a difference between the behaviour of ctx.needs_input_grad and input_tensor.requires_grad?

Hi folks,

Let’s assume we define a custom autograd Function:

```python
import torch

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        if ctx.needs_input_grad[0]:      # Is this...
            ...
        if input_tensor.requires_grad:   # ...equivalent to this?
            ...
        return input_tensor * 2
```

For this function, is there a difference from the autograd point of view between ctx.needs_input_grad[0] and input_tensor.requires_grad?
The reason I’m asking is that the two behave differently when the custom function runs under torch.compile. ctx.needs_input_grad triggers many more recompiles and has to be explicitly converted to a tuple, because torch.compile doesn’t like the get_set_attribute access. input_tensor.requires_grad seems to work much better with torch.compile.

Thanks a lot 🙂


needs_input_grad is actually available during the backward pass, where it tells you, for the current gradient computation, which of the inputs require gradients. In particular, different calls to autograd.grad() can lead to different values for this field.
On the other hand, input_tensor.requires_grad is a forward-time construct that controls whether the graph is built so that future calls to backward can happen.
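As a rough illustration of the backward-time use, here is a minimal sketch of the pattern described in the Extending PyTorch docs: backward checks ctx.needs_input_grad and skips the gradient computation for inputs that don't need it, returning None for those. MulFn, the observed list, and the tensor values are hypothetical and chosen only for illustration.

```python
import torch

# Hypothetical list used only to record what backward saw.
observed = []

class MulFn(torch.autograd.Function):
    """Elementwise a * b that only computes the gradients that are needed."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # One bool per forward input, in the same order as forward's arguments.
        observed.append(tuple(ctx.needs_input_grad))
        grad_a = grad_b = None
        if ctx.needs_input_grad[0]:
            grad_a = grad_out * b  # d(a*b)/da = b
        if ctx.needs_input_grad[1]:
            grad_b = grad_out * a  # d(a*b)/db = a
        return grad_a, grad_b

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = torch.tensor([4.0, 5.0, 6.0])  # does not require grad
MulFn.apply(a, b).sum().backward()

print(observed)  # [(True, False)]
print(a.grad)    # tensor([4., 5., 6.])
print(b.grad)    # None
```

Returning None for an input that doesn't require grad is accepted by autograd, so the check only saves work; it doesn't change the resulting gradients.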

The “Extending PyTorch” page of the PyTorch 2.1 documentation has a lot more details on how to use these features.

Thanks a lot for the explanation, that makes sense 🙂 I had read the linked documentation before, but the requires_grad behaviour wasn’t (and still isn’t) clear to me. The docs state that “Tensor arguments that track history (i.e., with requires_grad=True) will be converted to ones that don’t track history before the call”, so I would assume that input_tensor.requires_grad should be False inside forward. However, the following prints “forward input_tensor.requires_grad True”:

```python
import torch

class MyFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        print("forward input_tensor.requires_grad", input_tensor.requires_grad)
        print("forward id(input_tensor)", id(input_tensor))
        return input_tensor * 2

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

x = torch.ones([8], requires_grad=True)
out = MyFn.apply(x).sum()
```

Is this just an optimization that I shouldn’t rely on? If this is intended behaviour, I guess ctx.needs_input_grad[0] and input_tensor.requires_grad are equivalent in forward, since ctx.needs_input_grad also only has forward-time information at that point.


Yes, it is an optimization: the input Tensor inside forward is the same object as the one passed in, so its requires_grad field is unchanged. You can look at it, but you usually shouldn’t need to.
And yes, needs_input_grad is populated from the requires_grad information during the forward. It will be updated during the backward when necessary, so you should query it again there when needed.
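A minimal sketch, assuming plain eager mode (no torch.compile) with grad mode enabled, of what forward observes for an input that does and does not require grad. Probe and observed are hypothetical names used only for this illustration. The sketch records ctx.needs_input_grad[0], input_tensor.requires_grad, torch.is_grad_enabled(), and whether an operation inside forward tracks history.

```python
import torch

# Hypothetical list used only to capture what forward saw on each call.
observed = []

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        out = input_tensor * 2
        observed.append({
            "needs_input_grad": ctx.needs_input_grad[0],
            "requires_grad": input_tensor.requires_grad,
            # The forward body runs with grad mode disabled...
            "grad_enabled": torch.is_grad_enabled(),
            # ...so ops inside it don't record history, even though
            # input_tensor itself still reports requires_grad=True.
            "inner_requires_grad": out.requires_grad,
        })
        return out

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

x = torch.ones(8, requires_grad=True)
y = torch.ones(8)  # does not require grad

out_x = Probe.apply(x)
out_y = Probe.apply(y)

for record in observed:
    print(record)
print(out_x.requires_grad, out_y.requires_grad)  # True False
```

In this setting the two flags agree during forward. In backward, query ctx.needs_input_grad again rather than reusing the forward-time value.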