Sorry, that’s a mouthful. Let me try to make it clearer. Suppose I have:
class MyOP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input0, input1, input2):
        ...
        return output

    @staticmethod
    def backward(ctx, grad_output):
        ...
        return grad_input0, grad_input1, grad_input2
If I then do:
out = MyOP.apply(input0, input1, input2)
grad = torch.autograd.grad(out.sum(), input2)
can I somehow know in backward() that I only need grad_input2 so that I can avoid computing grad_input0 and grad_input1 and just return None in their place?
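For concreteness, here is a runnable toy version of the situation, where the op is just an element-wise product (a stand-in for my real, more expensive op):

import torch

class MyOP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input0, input1, input2):
        ctx.save_for_backward(input0, input1, input2)
        return input0 * input1 * input2  # placeholder for the real op

    @staticmethod
    def backward(ctx, grad_output):
        input0, input1, input2 = ctx.saved_tensors
        # All three gradients are computed unconditionally,
        # even when only one of them is needed downstream.
        grad_input0 = grad_output * input1 * input2
        grad_input1 = grad_output * input0 * input2
        grad_input2 = grad_output * input0 * input1
        return grad_input0, grad_input1, grad_input2

input0 = torch.randn(3, requires_grad=True)
input1 = torch.randn(3, requires_grad=True)
input2 = torch.randn(3, requires_grad=True)

out = MyOP.apply(input0, input1, input2)
grad_input2, = torch.autograd.grad(out.sum(), input2)  # grad_input0/1 were computed for nothing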
As a hacky workaround, I could do something like:
to_be_diff_wrt = [False, False, True]
out = MyOP.apply(input0, input1, input2, to_be_diff_wrt)

class MyOP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input0, input1, input2, to_be_diff_wrt):
        # save_for_backward only accepts tensors, so the flag list
        # has to be stashed on ctx as a plain attribute instead
        ctx.to_be_diff_wrt = to_be_diff_wrt
        ...
        return output
but I’m wondering if there’s a better way.
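Spelled out end to end (again with the toy product op standing in for the real one), the hack would look something like this; the trailing None in backward is the gradient slot for the non-tensor to_be_diff_wrt argument:

import torch

class MyOP(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input0, input1, input2, to_be_diff_wrt):
        ctx.to_be_diff_wrt = to_be_diff_wrt
        ctx.save_for_backward(input0, input1, input2)
        return input0 * input1 * input2  # placeholder for the real op

    @staticmethod
    def backward(ctx, grad_output):
        input0, input1, input2 = ctx.saved_tensors
        need0, need1, need2 = ctx.to_be_diff_wrt
        # Skip the gradients the caller said it doesn't want.
        grad_input0 = grad_output * input1 * input2 if need0 else None
        grad_input1 = grad_output * input0 * input2 if need1 else None
        grad_input2 = grad_output * input0 * input1 if need2 else None
        # One return value per forward() input; None for to_be_diff_wrt.
        return grad_input0, grad_input1, grad_input2, None

input0 = torch.randn(3, requires_grad=True)
input1 = torch.randn(3, requires_grad=True)
input2 = torch.randn(3, requires_grad=True)

out = MyOP.apply(input0, input1, input2, [False, False, True])
grad_input2, = torch.autograd.grad(out.sum(), input2)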