How to mark argument as nondifferentiable in a custom autograd Function

I’m trying to implement a custom autograd.Function where one of the arguments to the forward pass does not need a gradient (let’s say an index, a string or something like that). To illustrate, I modified the example from the tutorial as follows:

import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, second):
        print(second)   # <- second nondifferentiable argument
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input  # as a "hack"(?), you could add None as a second output

Now I get the error
RuntimeError: function MyReLUBackward returned an incorrect number of gradients (expected 2, got 1)

Is there a way to define that the second argument doesn’t need a gradient?

I noticed that you can just use return grad_input, None in the backward() function, but that seems like a hack. Or is this the way you’re supposed to do it?

Hi,

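No need to do anything specific in the forward. And in the backward, returning None is the right thing to do: backward() must return one value per input of forward() (not counting ctx), and None is the expected value for an input that is not a Tensor or does not require a gradient :slight_smile: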
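For reference, here is a minimal sketch of the tutorial's MyReLU with that fix applied: backward returns a tuple with one entry per forward argument, the gradient for input and None for the non-differentiable second argument. The label string passed to apply is just an illustrative placeholder.

```python
import torch


class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, second):
        # `second` is a plain Python value (e.g. a string or an index);
        # it is only used here and never needs a gradient.
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        # One return value per forward input (ctx excluded):
        # the gradient w.r.t. `input`, and None for `second`.
        return grad_input, None


x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
y = MyReLU.apply(x, "some label")
y.sum().backward()
print(x.grad)  # tensor([0., 1., 1.])
```

With the second return value in place, the "expected 2, got 1" error goes away and autograd simply skips the None entry.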


Ah, cool, then it is simpler than I imagined. Thanks a lot for your answer! :slight_smile: