Loss.backward() error when using a custom autograd function

Hi, I have a custom autograd.Function for a correlation layer (it implements cost volume construction on CUDA), in addition to the neural net whose forward pass produces the loss components. When I call loss.backward() I get the following error:

RuntimeError: function CorrelationFunctionBackward returned an incorrect number of gradients (expected 8, got 2)

which, as I figured out, is related to the autograd.Function. This is what it looks like:


import torch
from torch.autograd import Function

import correlation_cuda  # custom CUDA extension providing the correlation kernels


class CorrelationFunction(Function):

    @staticmethod
    def forward(ctx, input1, input2,
            pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
        # print("in CorrelationFunction forward")
        ctx.pad_size = pad_size
        ctx.kernel_size = kernel_size
        ctx.max_displacement = max_displacement
        ctx.stride1 = stride1
        ctx.stride2 = stride2
        ctx.corr_multiply = corr_multiply
        
        ctx.save_for_backward(input1, input2)

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()

            correlation_cuda.forward(input1, input2, rbot1, rbot2, output, 
                ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()

            grad_input1 = input1.new()
            grad_input2 = input2.new()

            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
                ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)

        return grad_input1, grad_input2

What do I need to change in the backward? I had previously gotten the forward pass working after reimplementing CorrelationFunction for PyTorch 1.13. Thanks in advance!

Yes, you need to return a gradient from backward for every input to your forward. For arguments that do not need gradients, return None.
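
For the forward above that means eight return values from backward (one per forward argument, not counting ctx). A minimal sketch of the adjusted backward, assuming only input1 and input2 need gradients:

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()

            grad_input1 = input1.new()
            grad_input2 = input2.new()

            correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
                ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)

        # One return value per forward argument: gradients for the two input tensors,
        # None for the six non-tensor configuration arguments.
        return grad_input1, grad_input2, None, None, None, None, None, None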


Thanks, the issue was solved after I added None to backward's return for all the remaining forward arguments (the correlation hyperparameters).
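
For reference, a hypothetical call site: custom Functions are invoked through .apply(), typically with positional arguments (feat1 and feat2 here are placeholder names for the CUDA feature maps produced by the network):

    # Hypothetical usage sketch: .apply() forwards the arguments to forward(ctx, ...)
    output = CorrelationFunction.apply(
        feat1, feat2,       # tensors that receive gradients
        3, 3, 20, 1, 2, 1,  # pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply
    )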