I’m trying to implement a custom autograd.Function where one of the arguments for the forward pass does not need a gradient (say, an index or a string or something like that). To illustrate, I just modified the example from the tutorial as follows:
    class MyReLU(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, second):
            print(second)  # <- second, non-differentiable argument
            ctx.save_for_backward(input)
            return input.clamp(min=0)

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            grad_input = grad_output.clone()
            grad_input[input < 0] = 0
            return grad_input  # as a "hack"(?) you could just add None as a second output
Now I get the error:

    RuntimeError: function MyReLUBackward returned an incorrect number of gradients (expected 2, got 1)
Is there a way to define that the second argument doesn’t need a gradient?
I noticed that you can just use return grad_input, None in the backward() function, but that just seems like a hack. Or is this the way you’re supposed to do it?
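
For reference, here is a minimal sketch of what I mean (the input tensor x, output y, and the "some_tag" string are just my own illustrative placeholders), with backward returning None for the extra argument:

    import torch

    class MyReLU(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, second):
            print(second)  # non-differentiable extra argument
            ctx.save_for_backward(input)
            return input.clamp(min=0)

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            grad_input = grad_output.clone()
            grad_input[input < 0] = 0
            # one gradient per forward argument; None for the non-tensor one
            return grad_input, None

    x = torch.randn(5, requires_grad=True)
    y = MyReLU.apply(x, "some_tag")
    y.sum().backward()
    print(x.grad)

This runs without the RuntimeError, I’m just not sure whether returning None like this is the intended pattern.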