Custom CUDA extension backward pass does not appear as grad_fn

I am implementing a basic sharpener with one parameter, just to test the CUDA extension and get it working as a basis for further work. The Python class is defined as follows.

class NASFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, A):
        output = nas_cuda.forward(input, A)
        variables = output
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output = nas_cuda.backward(grad_output[0], ctx.saved_variables[1])
        return output, ctx.saved_variables[1]


class NAS(nn.Module):
    def __init__(self):
        super(NAS, self).__init__()
        self.A = torch.nn.Parameter(torch.ones(1) * 0.5).cuda()

    def forward(self, input):
        return NASFunction.apply(input, self.A)

I then define a forward pass and a dummy backward pass. I cannot invoke the dummy backward pass to test and extend it, because it never appears as a grad_fn and therefore can never be applied. The CUDA implementation is below.

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>


template <typename scalar_t>
__global__ void nas_cuda_forward_kernel(
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> input,
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> output,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> Ap
){

  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int j = blockIdx.y * blockDim.y + threadIdx.y;

  if (i < 1) i = 1;
  if (i > 254) i = 254;
  if (j < 1) j = 1;
  if (j > 254) j = 254;

  // a b c   0 A 0 
  // d e f . A 1 A
  // g h i   0 A 0 

  float norm_factor = (1 + 4 * 0.5);

  const auto b_r = input[0][0][i-1][j];
  const auto b_g = input[0][1][i-1][j];
  const auto b_b = input[0][2][i-1][j];

  const auto d_r = input[0][0][i][j-1];
  const auto d_g = input[0][1][i][j-1];
  const auto d_b = input[0][2][i][j-1];
  const auto e_r = input[0][0][i][j];
  const auto e_g = input[0][1][i][j];
  const auto e_b = input[0][2][i][j];
  const auto f_r = input[0][0][i][j+1];
  const auto f_g = input[0][1][i][j+1];
  const auto f_b = input[0][2][i][j+1];

  const auto h_r = input[0][0][i+1][j];
  const auto h_g = input[0][1][i+1][j];
  const auto h_b = input[0][2][i+1][j];

  float Ap_ = Ap[0];

  output[0][0][i][j] = (e_r + b_r * Ap_ + d_r * Ap_ + f_r * Ap_ + h_r * Ap_) / norm_factor;
  output[0][1][i][j] = (e_g + b_g * Ap_ + d_g * Ap_ + f_g * Ap_ + h_g * Ap_) / norm_factor;
  output[0][2][i][j] = (e_b + b_b * Ap_ + d_b * Ap_ + f_b * Ap_ + h_b * Ap_) / norm_factor;

}

std::vector<torch::Tensor> nas_cuda_forward(
  torch::Tensor input,
  torch::Tensor A
){
  auto output = torch::zeros_like(input);
  const auto batch_size = input.size(0);
  const auto channels = input.size(1);
  const auto h = input.size(2);
  const auto w = input.size(3);

  const dim3 blockDim(16, 16);
  const dim3 gridDim(16, 16);

  AT_DISPATCH_FLOATING_TYPES(input.type(), "nas_cuda_forward", ([&] {
    nas_cuda_forward_kernel<scalar_t><<<gridDim, blockDim>>>(
      input.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
      output.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
      A.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>()
      );
  }));

  return {output, A};
}


std::vector<torch::Tensor> nas_cuda_backward(
  torch::Tensor grad_input,
  torch::Tensor A
){
    A = A* 0.5;
    return {grad_input, A};
}

Then, if I execute the following Python to test it:

import torch

from cuda.nas import NAS

x = torch.ones(1,3,256,256, requires_grad=True).cuda()

nas = NAS()

y = nas(x)

it produces the following output (truncated here):

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 1., 1.,  ..., 1., 1., 0.],
          [0., 1., 1.,  ..., 1., 1., 0.],
          ...,
          [0., 1., 1.,  ..., 1., 1., 0.],
          [0., 1., 1.,  ..., 1., 1., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 1., 1.,  ..., 1., 1., 0.],
          [0., 1., 1.,  ..., 1., 1., 0.],
          ...,
          [0., 1., 1.,  ..., 1., 1., 0.],
          [0., 1., 1.,  ..., 1., 1., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]], device='cuda:0')

And the A parameter is returned as

tensor([0.5000], device='cuda:0', grad_fn=<ToCopyBackward0>)

There is no grad_fn on the output tensor, hence I cannot use autograd to propagate any information back.

What am I missing here? I am failing to see why the output of nas has no grad_fn.
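
For illustration, the check that fails is simply inspecting grad_fn on the result and trying to backpropagate through it (a minimal sketch continuing from the snippet above):

y = nas(x)
print(y[0].grad_fn)    # prints None instead of a <NASFunctionBackward> node
y[0].sum().backward()  # raises a RuntimeError because the output does not require grad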

You are creating non-leaf tensors by calling differentiable operations (here .cuda()) on the input tensor and on the parameter:

x = torch.ones(1,3,256,256, requires_grad=True).cuda()
...
self.A = torch.nn.Parameter(torch.ones(1) * 0.5).cuda()

Call the cuda() operation (or any other to() operation) on the tensor before wrapping it in an nn.Parameter, or specify the device directly when creating the tensor:

x = torch.ones(1,3,256,256, requires_grad=True, device="cuda")
...
self.A = torch.nn.Parameter(torch.ones(1, device="cuda") * 0.5)
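
As a quick sanity check of the parameter fix (a minimal sketch, assuming the NAS module above with the corrected definition):

nas = NAS()

print(isinstance(nas.A, torch.nn.Parameter))  # True: A is no longer replaced by the result of .cuda()
print(nas.A.is_leaf, nas.A.requires_grad)     # True True: a proper leaf parameter
print(list(nas.parameters()))                 # now contains A; with the original code this list is empty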

Hi, thanks so much for the guidance.

I implemented your changes, but the gradient requirement still disappears:

x = torch.ones(1,3,256,256, requires_grad=True, device="cuda")
(Pdb) x.requires_grad
True

which, when passed through the model, shows

(Pdb) nas(x)[0].requires_grad
False

Is it perhaps an issue with my CUDA kernel that prevents it from being able to track gradient information?

Thanks

Could you create a GitHub repository containing a minimal executable code snippet as well as the build instructions to reproduce the issue, please?
Right now it seems the binding from your extension to the Python API is missing.
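
For context, a minimal Python-side check of the binding could look like the sketch below. The file names are assumptions, and the C++ file needs a PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) block that exposes forward and backward; without it, nas_cuda.forward and nas_cuda.backward do not exist on the Python side.

# Sketch only: assumes the listed source files exist and that nas_cuda.cpp
# defines PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) with m.def("forward", ...)
# and m.def("backward", ...).
from torch.utils.cpp_extension import load

nas_cuda = load(
    name="nas_cuda",
    sources=["nas_cuda.cpp", "nas_cuda_kernel.cu"],  # assumed file layout
    verbose=True,
)

print(hasattr(nas_cuda, "forward"), hasattr(nas_cuda, "backward"))  # both should print True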