How do I write a backward function in a custom C++ extension?

Hi.
I am updating C++ extension code that was written for an older version of PyTorch, and I want to run it on the latest PyTorch version.

Specifically, I am trying to port this model:

I am working through this tutorial:
https://pytorch.org/tutorials/advanced/cpp_extension.html

Below is the Python code I wrote. (The .cpp and .cu files have changed a bit, but that doesn't matter for this question.)


import torch
import torch.nn as nn
import correlation_cuda
import math

class CorrelationFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, input1, input2,
        pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply,
        ):
        ctx.save_for_backward(input1, input2)
        B, C, H, W = input1.shape
        padH = int(H + 2 * pad_size)
        padW = int(W + 2 * pad_size)
        out_channel = int(((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1))
        kernel_radius = (kernel_size - 1) // 2
        border_radius = kernel_radius + max_displacement
        outH = int(math.ceil((padH - 2 * border_radius) / stride1))
        outW = int(math.ceil((padW - 2 * border_radius) / stride1))

        rbot1 = torch.zeros((B, padH, padW, C), dtype=torch.float32, device=input1.device)
        rbot2 = torch.zeros((B, padH, padW, C), dtype=torch.float32, device=input1.device)
        output = torch.zeros((B, out_channel, outH, outW), dtype=torch.float32, device=input1.device)

        correlation_cuda.forward(
            input1, input2, rbot1, rbot2, output, 
            pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply
            )

        return output

    @staticmethod
    def backward(
        ctx, grad_output,
        pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply
        ):
        input1, input2 = ctx.saved_tensors
        B, C, H, W = input1.shape
        padH = H + 2 * pad_size
        padW = W + 2 * pad_size

        rbot1 = torch.zeros((B, padH, padW, C), dtype=torch.float32, device=input1.device)
        rbot2 = torch.zeros((B, padH, padW, C), dtype=torch.float32, device=input1.device)

        grad_input1 = torch.zeros((B, C, H, W), dtype=torch.float32, device=input1.device)
        grad_input2 = torch.zeros((B, C, H, W), dtype=torch.float32, device=input1.device)

        correlation_cuda.backward(
            input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
            pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply
            )

        return grad_input1, grad_input2

class Correlation(nn.Module):
    def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
        super(Correlation, self).__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

    def forward(self, input1, input2):

        result = CorrelationFunction.apply(
            input1, input2, self.pad_size, self.kernel_size, self.max_displacement,
            self.stride1, self.stride2, self.corr_multiply,
            )

        return result

    def backward(self, grad_output):

        result = CorrelationFunction.apply(
            grad_output, self.pad_size, self.kernel_size, self.max_displacement,
            self.stride1, self.stride2, self.corr_multiply,
            )

        return result


cor = Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1).to(device='cuda:0')

input1 = torch.randn((2, 4, 32, 32), dtype=torch.float32, device='cuda:0')
input2 = torch.randn((2, 4, 32, 32), dtype=torch.float32, device='cuda:0')
conv1 = nn.Conv2d(4, 4, 1, 1, padding=1).to(device='cuda:0')
conv2 = nn.Conv2d(4, 4, 1, 1, padding=1).to(device='cuda:0')

input1 = conv1(input1)
input2 = conv2(input2)

output = cor(input1, input2)
print(output.shape)

loss = 2 - output.max()
loss.backward()

Currently, the backward pass fails with my code.
In the tutorial, the LLTM module does not define a backward function.
How should I write the backward function?
Thanks for reading.

The custom backward function belongs in the autograd.Function, not in the custom nn.Module, so remove it from the latter. Note also that autograd calls Function.backward as backward(ctx, grad_output), with exactly one gradient per forward output, and expects it to return one gradient per forward input (None for the non-tensor arguments). The hyperparameters in your backward signature should therefore be saved on ctx in forward and read back in backward.
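
Here is a minimal sketch of the corrected autograd.Function, assuming your compiled correlation_cuda extension keeps the forward/backward signatures used in your snippet (untested; adapt names to your extension):

import math
import torch
import correlation_cuda  # assumed to keep the signatures used in the snippet above

class CorrelationFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, input1, input2,
        pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply,
        ):
        ctx.save_for_backward(input1, input2)
        # Stash the non-tensor hyperparameters on ctx; autograd will not pass
        # them to backward, so backward must read them from here instead.
        ctx.params = (pad_size, kernel_size, max_displacement,
                      stride1, stride2, corr_multiply)

        B, C, H, W = input1.shape
        padH = H + 2 * pad_size
        padW = W + 2 * pad_size
        out_channel = int(((max_displacement / stride2) * 2 + 1) ** 2)
        kernel_radius = (kernel_size - 1) // 2
        border_radius = kernel_radius + max_displacement
        outH = int(math.ceil((padH - 2 * border_radius) / stride1))
        outW = int(math.ceil((padW - 2 * border_radius) / stride1))

        # new_zeros keeps the dtype and device of input1
        rbot1 = input1.new_zeros((B, padH, padW, C))
        rbot2 = input1.new_zeros((B, padH, padW, C))
        output = input1.new_zeros((B, out_channel, outH, outW))

        correlation_cuda.forward(
            input1, input2, rbot1, rbot2, output,
            pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply,
            )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Exactly one incoming gradient, because forward returned one tensor.
        input1, input2 = ctx.saved_tensors
        pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply = ctx.params

        B, C, H, W = input1.shape
        padH = H + 2 * pad_size
        padW = W + 2 * pad_size

        rbot1 = input1.new_zeros((B, padH, padW, C))
        rbot2 = input1.new_zeros((B, padH, padW, C))
        grad_input1 = torch.zeros_like(input1)
        grad_input2 = torch.zeros_like(input2)

        correlation_cuda.backward(
            input1, input2, rbot1, rbot2, grad_output.contiguous(),
            grad_input1, grad_input2,
            pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply,
            )

        # One return value per forward input: gradients for the two tensors,
        # then None for each of the six non-tensor hyperparameters.
        return grad_input1, grad_input2, None, None, None, None, None, None

With this, the Correlation module only needs its forward; autograd routes gradients through CorrelationFunction.backward automatically. If your kernel supports float64, you can sanity-check the gradients with torch.autograd.gradcheck on small double-precision inputs.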