I am implementing a basic sharpener with a single parameter, just to get the CUDA extension working as a basis for further work. The Python classes are defined as follows:
class NASFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, A):
        output = nas_cuda.forward(input, A)
        variables = output
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output = nas_cuda.backward(grad_output[0], ctx.saved_variables[1])
        return output, ctx.saved_variables[1]

class NAS(nn.Module):
    def __init__(self):
        super(NAS, self).__init__()
        self.A = torch.nn.Parameter(torch.ones(1) * 0.5).cuda()

    def forward(self, input):
        return NASFunction.apply(input, self.A)
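For context, nas_cuda is the compiled extension module. I build and load it roughly like this (only a sketch; the file names here are just what my project happens to use):

from torch.utils.cpp_extension import load

# JIT-compile the extension; nas_cuda.cpp holds the pybind11 bindings,
# nas_cuda_kernel.cu holds the kernels shown further down
nas_cuda = load(name="nas_cuda", sources=["nas_cuda.cpp", "nas_cuda_kernel.cu"])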
I then define a forward pass and a dummy backward pass. I cannot invoke the dummy backward pass to test and develop it further, because the output never gets a grad_fn, so backward is never applied. The CUDA implementation is below.
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
template <typename scalar_t>
__global__ void nas_cuda_forward_kernel(
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> input,
    torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> output,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> Ap
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    int j = blockIdx.y * blockDim.y + threadIdx.y;
    // clamp to the interior of the (hard-coded) 256x256 image
    if (i < 1) i = 1;
    if (i > 254) i = 254;
    if (j < 1) j = 1;
    if (j > 254) j = 254;
    // a b c     0 A 0
    // d e f  .  A 1 A
    // g h i     0 A 0
    float norm_factor = (1 + 4 * 0.5);  // note: 0.5 hard-coded here rather than Ap[0] for now
    const auto b_r = input[0][0][i-1][j];
    const auto b_g = input[0][1][i-1][j];
    const auto b_b = input[0][2][i-1][j];
    const auto d_r = input[0][0][i][j-1];
    const auto d_g = input[0][1][i][j-1];
    const auto d_b = input[0][2][i][j-1];
    const auto e_r = input[0][0][i][j];
    const auto e_g = input[0][1][i][j];
    const auto e_b = input[0][2][i][j];
    const auto f_r = input[0][0][i][j+1];
    const auto f_g = input[0][1][i][j+1];
    const auto f_b = input[0][2][i][j+1];
    const auto h_r = input[0][0][i+1][j];
    const auto h_g = input[0][1][i+1][j];
    const auto h_b = input[0][2][i+1][j];
    float Ap_ = Ap[0];
    output[0][0][i][j] = (e_r + b_r * Ap_ + d_r * Ap_ + f_r * Ap_ + h_r * Ap_) / norm_factor;
    output[0][1][i][j] = (e_g + b_g * Ap_ + d_g * Ap_ + f_g * Ap_ + h_g * Ap_) / norm_factor;
    output[0][2][i][j] = (e_b + b_b * Ap_ + d_b * Ap_ + f_b * Ap_ + h_b * Ap_) / norm_factor;
}

std::vector<torch::Tensor> nas_cuda_forward(
    torch::Tensor input,
    torch::Tensor A
) {
    auto output = torch::zeros_like(input);
    const auto batch_size = input.size(0);
    const auto channels = input.size(1);
    const auto h = input.size(2);
    const auto w = input.size(3);
    const dim3 blockDim(16, 16);
    const dim3 gridDim(16, 16);  // 16x16 blocks of 16x16 threads = 256x256 pixels
    AT_DISPATCH_FLOATING_TYPES(input.type(), "nas_cuda_forward", ([&] {
        nas_cuda_forward_kernel<scalar_t><<<gridDim, blockDim>>>(
            input.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
            output.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
            A.packed_accessor<scalar_t, 1, torch::RestrictPtrTraits, size_t>());
    }));
    return {output, A};
}

std::vector<torch::Tensor> nas_cuda_backward(
    torch::Tensor grad_input,
    torch::Tensor A
) {
    // dummy backward for now: just pass things through
    A = A * 0.5;
    return {grad_input, A};
}
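For reference, this is roughly what the forward kernel is meant to compute, written in plain PyTorch (a sketch I use to sanity-check the CUDA output, not differentiable; it assumes a 1x3xHxW input and fills only the interior pixels, like the kernel above):

import torch
import torch.nn.functional as F

def nas_reference(input, A):
    # 3x3 cross kernel: centre weight 1, the four neighbours weighted by A
    a = A.item()  # plain float, so this reference is not part of the autograd graph
    kernel = torch.tensor([[0.0,   a, 0.0],
                           [  a, 1.0,   a],
                           [0.0,   a, 0.0]], device=input.device, dtype=input.dtype)
    kernel = kernel.view(1, 1, 3, 3).repeat(3, 1, 1, 1)  # one copy per channel (depthwise)
    norm_factor = 1 + 4 * 0.5  # same hard-coded normalisation as the CUDA kernel
    out = torch.zeros_like(input)
    # valid depthwise convolution covers the interior; the 1-pixel border stays zero,
    # matching what the CUDA kernel leaves untouched
    out[:, :, 1:-1, 1:-1] = F.conv2d(input, kernel, groups=3) / norm_factor
    return out

With a 1x3x256x256 tensor of ones and A = 0.5 this gives 1.0 everywhere in the interior and 0.0 on the border, which is exactly what the CUDA kernel produces in the test below.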
Then I execute the following Python to test it:
import torch
from cuda.nas import NAS
x = torch.ones(1,3,256,256, requires_grad=True).cuda()
nas = NAS()
y = nas(x)
This produces the following (truncated) output:
[[0., 0., 0., ..., 0., 0., 0.],
[0., 1., 1., ..., 1., 1., 0.],
[0., 1., 1., ..., 1., 1., 0.],
...,
[0., 1., 1., ..., 1., 1., 0.],
[0., 1., 1., ..., 1., 1., 0.],
[0., 0., 0., ..., 0., 0., 0.]],
[[0., 0., 0., ..., 0., 0., 0.],
[0., 1., 1., ..., 1., 1., 0.],
[0., 1., 1., ..., 1., 1., 0.],
...,
[0., 1., 1., ..., 1., 1., 0.],
[0., 1., 1., ..., 1., 1., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]], device='cuda:0')
The A parameter is returned as:
tensor([0.5000], device='cuda:0', grad_fn=<ToCopyBackward0>)
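Eventually I want to exercise the dummy backward pass with something like this (only a sketch of the intended test; the loss is arbitrary, I just need backward() to reach NASFunction.backward):

loss = y.sum()
loss.backward()
print(nas.A.grad)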
However, there is no grad_fn on the output of nas, so autograd has nothing to propagate back through.
What am I missing here? I am failing to see why the output of nas has no grad_fn.