Custom C++/CUDA layer: save_for_backward equivalent

Hello, I have implemented a custom layer for PyTorch and want to speed it up using CUDA.
I was following the tutorial at https://pytorch.org/tutorials/advanced/cpp_extension.html, but it doesn't mention how to save tensors in the forward pass (the equivalent of ctx.save_for_backward) on the CUDA side.
I would really appreciate it if someone could explain how to do this, or suggest another way to achieve it.
Here is the Python implementation:

class NeighborFill(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, z, neighbors):
        dt = TimestampMillisec64()  # timing helper (definition not shown)
        nx = torch.zeros((1, x.shape[1], neighbors.shape[1] * neighbors.shape[2])).cuda()
        ny = torch.zeros((1, y.shape[1], neighbors.shape[1] * neighbors.shape[2])).cuda()
        nz = torch.zeros((1, z.shape[1], neighbors.shape[1] * neighbors.shape[2])).cuda()
        for i in range(neighbors.shape[1]):
            for j in range(neighbors.shape[2]):
                index = neighbors[0][i][j] - 1  # 1-based indices, 0 means "no neighbor"
                if index >= 0:
                    nx[0, :, i * neighbors.shape[2] + j] = x[0, :, index]
                    ny[0, :, i * neighbors.shape[2] + j] = y[0, :, index]
                    nz[0, :, i * neighbors.shape[2] + j] = z[0, :, index]
        ctx.save_for_backward(x, y, z, neighbors)
        ctx.mark_non_differentiable(neighbors)  # note: intended for outputs; returning None below is what matters
        return nx, ny, nz

    @staticmethod
    def backward(ctx, gradoutx, gradouty, gradoutz):
        dt = TimestampMillisec64()  # timing helper (definition not shown)
        x, y, z, neighbors = ctx.saved_tensors
        gradinpx = torch.zeros((1, x.shape[1], x.shape[2])).cuda()
        gradinpy = torch.zeros((1, y.shape[1], y.shape[2])).cuda()
        gradinpz = torch.zeros((1, z.shape[1], z.shape[2])).cuda()
        for i in range(neighbors.shape[1]):
            for j in range(neighbors.shape[2]):
                index = neighbors[0][i][j] - 1
                if index >= 0:
                    # the gradient flows back to the gathered source point `index`
                    gradinpx[0, :, index] += gradoutx[0, :, i * neighbors.shape[2] + j]
                    gradinpy[0, :, index] += gradouty[0, :, i * neighbors.shape[2] + j]
                    gradinpz[0, :, index] += gradoutz[0, :, i * neighbors.shape[2] + j]

        return gradinpx, gradinpy, gradinpz, None  # no gradient for the integer neighbors
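
For context, the function is invoked through apply; the shapes below are just an example (batch 1, C feature maps, P points, K neighbors per point, with a neighbor entry of 0 meaning "no neighbor"):

import torch

C, P, K = 4, 8, 3  # example sizes
x = torch.randn(1, C, P, device="cuda", requires_grad=True)
y = torch.randn(1, C, P, device="cuda", requires_grad=True)
z = torch.randn(1, C, P, device="cuda", requires_grad=True)
neighbors = torch.randint(0, P + 1, (1, P, K), device="cuda")  # 1-based indices
nx, ny, nz = NeighborFill.apply(x, y, z, neighbors)            # each of shape (1, C, P*K)
(nx.sum() + ny.sum() + nz.sum()).backward()                    # exercises the custom backward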

And here is the CUDA code; I have only the forward pass for now:

#include <torch/extension.h>
#include <vector>

// input-checking macros from the tutorial
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> neighborfill_cuda_forward(torch::Tensor x, torch::Tensor y, torch::Tensor z, torch::Tensor neighbors);  // defined in the .cu file

std::vector<torch::Tensor> neighborfill_forward(torch::Tensor x, torch::Tensor y, torch::Tensor z, torch::Tensor neighbors){
    CHECK_INPUT(x);
    CHECK_INPUT(y);
    CHECK_INPUT(z);
    CHECK_CUDA(neighbors);

    return neighborfill_cuda_forward(x, y, z, neighbors);
}

// CUDA file -- the kernel template must be defined before the launcher that instantiates it
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

template <typename scalar_t, typename neighbors_t>
__global__ void fill_neighbors_kernel(
        torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> nx,
        torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> ny,
        torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> nz,
        torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> x,
        torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> y,
        torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> z,
        torch::PackedTensorAccessor32<neighbors_t,3,torch::RestrictPtrTraits> neighbors,
        int tsize, int nsize, int fmaps){
    // one thread per point, matching the outer i-loop of the Python version
    int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if(index < tsize){
        for(int i = 0; i < nsize; i++){
            int nind = neighbors[0][index][i] - 1;  // 1-based indices, 0 means "no neighbor"
            if(nind >= 0){
                for(int j = 0; j < fmaps; j++){
                    nx[0][j][index * nsize + i] = x[0][j][nind];
                    ny[0][j][index * nsize + i] = y[0][j][nind];
                    nz[0][j][index * nsize + i] = z[0][j][nind];
                }
            }
        }
    }
}  // this closing brace was missing

std::vector<torch::Tensor> neighborfill_cuda_forward(torch::Tensor x, torch::Tensor y, torch::Tensor z, torch::Tensor neighbors){
    auto batch_size = x.size(0);
    auto state_size = x.size(2);
    int threads = 1024;
    dim3 blocks((state_size + threads - 1) / threads, batch_size);
    // sizes go in braces (an IntArrayRef); x.options() keeps the outputs on the GPU with the input dtype
    torch::Tensor nx = torch::zeros({1, x.size(1), neighbors.size(1) * neighbors.size(2)}, x.options());
    torch::Tensor ny = torch::zeros({1, y.size(1), neighbors.size(1) * neighbors.size(2)}, y.options());
    torch::Tensor nz = torch::zeros({1, z.size(1), neighbors.size(1) * neighbors.size(2)}, z.options());

    // type() is a runtime value, not a compile-time template argument;
    // AT_DISPATCH_FLOATING_TYPES selects the concrete scalar_t instead
    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "neighborfill_cuda_forward", ([&] {
        fill_neighbors_kernel<scalar_t, int64_t><<<blocks, threads>>>(
            nx.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
            ny.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
            nz.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
            x.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
            y.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
            z.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
            neighbors.packed_accessor32<int64_t, 3, torch::RestrictPtrTraits>(),  // assumes neighbors is a LongTensor
            x.size(2), neighbors.size(2), x.size(1));
    }));

    return {nx, ny, nz};
}
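For reference, the extension can then be built and loaded from Python with torch.utils.cpp_extension.load; the file names below are assumptions, and the .cpp file additionally needs the usual PYBIND11_MODULE binding from the tutorial (m.def("forward", &neighborfill_forward, ...)):

from torch.utils.cpp_extension import load

neighborfill_cuda = load(
    name="neighborfill_cuda",
    sources=["neighborfill.cpp", "neighborfill_kernel.cu"],  # assumed file names
)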
In the tutorial you link, the C++ function is then used inside an autograd.Function on the Python side. So you can save the tensors in the autograd.Function that wraps your C++/CUDA implementation, exactly as you do in your current pure-Python implementation.
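
For concreteness, here is a minimal sketch of such a wrapper, assuming the extension is compiled into a module named neighborfill_cuda that exposes forward (the backward call is hypothetical until you implement the CUDA backward):

import torch
import neighborfill_cuda  # the compiled extension; the module name is an assumption

class NeighborFillCUDA(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, z, neighbors):
        nx, ny, nz = neighborfill_cuda.forward(x, y, z, neighbors)
        # save_for_backward lives on the Python side, exactly as in your pure-Python version
        ctx.save_for_backward(x, y, z, neighbors)
        return nx, ny, nz

    @staticmethod
    def backward(ctx, gradoutx, gradouty, gradoutz):
        x, y, z, neighbors = ctx.saved_tensors
        # hypothetical CUDA backward, implemented analogously to your Python loop
        gx, gy, gz = neighborfill_cuda.backward(
            gradoutx.contiguous(), gradouty.contiguous(), gradoutz.contiguous(), neighbors)
        return gx, gy, gz, None  # no gradient for the integer neighbors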

Thank you. I didn't realize I needed to write a wrapper.