When writing a `torch.autograd.Function`, is there some way for `backward()` to know with respect to which inputs the gradients are being requested?

Sorry, that’s a mouthful. Let me try to make it clearer. Suppose I have:

```python
class MyOP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input0, input1, input2):
        ...
        return output

    @staticmethod
    def backward(ctx, grad_output):
        ...
        return grad_input0, grad_input1, grad_input2
```

If I then do:

```python
out = MyOP.apply(input0, input1, input2)
grad = torch.autograd.grad(out.sum(), input2)
```

can I somehow know in backward() that I only need grad_input2 so that I can avoid computing grad_input0 and grad_input1 and just return None in their place?
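To make the question concrete, here is a minimal sketch using a hypothetical toy op `MulAdd` (out = a * b + c; the name and the `calls` list are illustrative, not part of PyTorch). It records `ctx.needs_input_grad` and shows that `backward()` still runs and computes all three gradients even though only the gradient for `c` is requested; autograd simply throws away the ones nobody asked for.

```python
import torch

calls = []  # hypothetical log of what backward() saw

class MulAdd(torch.autograd.Function):
    """Hypothetical toy op: out = a * b + c."""

    @staticmethod
    def forward(ctx, a, b, c):
        ctx.save_for_backward(a, b)
        return a * b + c

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        calls.append(ctx.needs_input_grad)
        # All three gradients are computed here, even if only one is requested
        return grad_output * b, grad_output * a, grad_output

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
c = torch.randn(3, requires_grad=True)
out = MulAdd.apply(a, b, c)
(grad_c,) = torch.autograd.grad(out.sum(), c)
print(grad_c)  # tensor([1., 1., 1.])
print(calls)   # [(True, True, True)]
```

Nothing inside `backward()` reveals that only `c` was requested: `needs_input_grad` is all `True` because all three inputs require grad.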

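As a hacky workaround, I could pass an extra argument saying which gradients I want:

```python
class MyOP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input0, input1, input2, to_be_diff_wrt):
        # save_for_backward() only accepts tensors, so keep the plain list on ctx
        ctx.to_be_diff_wrt = to_be_diff_wrt
        ...
        return output

    @staticmethod
    def backward(ctx, grad_output):
        ...  # compute only the grads flagged in ctx.to_be_diff_wrt, None for the rest
        # one return value per forward() argument; the flag list itself gets None
        return grad_input0, grad_input1, grad_input2, None

to_be_diff_wrt = [False, False, True]
out = MyOP.apply(input0, input1, input2, to_be_diff_wrt)
```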

but I’m wondering if there’s a better way.

You can use `ctx.needs_input_grad` inside `backward()` to check which inputs actually require gradients. It is a tuple of booleans, one per `forward()` input, indicating whether that input had `requires_grad=True` when the op was called (it is not updated if `requires_grad` changes after the op is called).
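A minimal sketch of that pattern, again with a hypothetical toy op `MulAdd` (out = a * b + c, not a PyTorch built-in): each gradient is computed only when the matching entry of `ctx.needs_input_grad` is `True`, and `None` is returned otherwise. Here `b` does not require grad, so its gradient is skipped.

```python
import torch

class MulAdd(torch.autograd.Function):
    """Hypothetical toy op: out = a * b + c."""

    @staticmethod
    def forward(ctx, a, b, c):
        ctx.save_for_backward(a, b)
        return a * b + c

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_a = grad_b = grad_c = None
        # needs_input_grad has one bool per forward() input
        if ctx.needs_input_grad[0]:
            grad_a = grad_output * b
        if ctx.needs_input_grad[1]:
            grad_b = grad_output * a
        if ctx.needs_input_grad[2]:
            grad_c = grad_output
        return grad_a, grad_b, grad_c

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)  # plain tensor: its gradient is never computed
c = torch.randn(3, requires_grad=True)
out = MulAdd.apply(a, b, c)
out.sum().backward()
print(a.grad is not None, b.grad is None)  # True True
```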

I don’t think this works in my case, because all three inputs require grad and therefore that tuple will be (True, True, True). However, sometimes I only need one or two of the gradients and discard the rest, and since they are quite costly to compute, I would like to skip them.

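Interesting question! If all inputs require gradients, `ctx.needs_input_grad` won’t help with filtering, because it only reflects `requires_grad` at call time, not which gradients `torch.autograd.grad()` is later asked for. Passing a flag as an extra argument to `apply()` and storing it on `ctx` in `forward()`, just as you suggested, is a reasonable approach: it’s not hacky at all, and it’s quite practical in this case. Keep in mind that `save_for_backward()` only accepts tensors, so set the flags as a plain `ctx` attribute, and return `None` for that extra argument in `backward()`.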
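A runnable sketch of the flag approach, assuming a hypothetical toy op `MulAdd` (out = a * b + c) and a hypothetical `wanted` argument (a tuple of three bools); neither is a PyTorch API. A gradient is computed only if the input requires grad and the caller asked for it.

```python
import torch

class MulAdd(torch.autograd.Function):
    """Hypothetical toy op: out = a * b + c, with a caller-supplied mask
    (`wanted`) saying which gradients backward() should compute."""

    @staticmethod
    def forward(ctx, a, b, c, wanted):
        # `wanted` is a plain tuple of bools, not a tensor, so store it on ctx
        # directly (save_for_backward only accepts tensors)
        ctx.wanted = wanted
        ctx.save_for_backward(a, b)
        return a * b + c

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # Only the first three entries correspond to tensor inputs
        need = [n and w for n, w in zip(ctx.needs_input_grad[:3], ctx.wanted)]
        grad_a = grad_output * b if need[0] else None
        grad_b = grad_output * a if need[1] else None
        grad_c = grad_output if need[2] else None
        # One return value per forward() argument; the mask itself gets None
        return grad_a, grad_b, grad_c, None

a, b, c = (torch.randn(3, requires_grad=True) for _ in range(3))
out = MulAdd.apply(a, b, c, (False, False, True))
(grad_c,) = torch.autograd.grad(out.sum(), c)  # grads for a and b are skipped
```

The trade-off is that the caller must know up front which gradients it needs: a gradient masked out in `forward()` is never computed for that graph, so requesting it later will not give the true value.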