I’m trying to implement a custom `autograd.Function` where one of the arguments to the forward pass does not need a gradient (say, an index or a string). To illustrate, I modified the example from the tutorial as follows:
```python
class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, second):
        print(second)  # <- second, non-differentiable argument
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input  # as a "hack"(?) you could just add None as a second output
```
Now I get the error:

```
RuntimeError: function MyReLUBackward returned an incorrect number of gradients (expected 2, got 1)
```
Is there a way to define that the second argument doesn’t need a gradient?
I noticed that you can just use `return grad_input, None` in the `backward()` function, but that seems like a hack. Or is this the way you’re supposed to do it?
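For reference, here is a self-contained sketch of the variant that returns `None` for the second argument (the example value `"some-index"` passed to `apply` is just an illustration): `backward()` returns one output per `forward()` input, and `None` marks an input as non-differentiable.

```python
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, second):
        # `second` is an extra, non-differentiable argument (e.g. an index or a string)
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        # one gradient per forward() argument; None marks `second` as non-differentiable
        return grad_input, None

x = torch.randn(5, requires_grad=True)
y = MyReLU.apply(x, "some-index")
y.sum().backward()
print(x.grad)  # 1 where x >= 0, else 0
```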