Issue with CUDA backwards kernel

I am trying to write a custom backward function using the CUDAExtension support in PyTorch, but the kernel always sees gradient values of 0 (except at the first location), even though printing the gradient on the host before dispatching it to the kernel shows the correct values.

Here are my CUDA kernels:

#include <iostream>

#include <torch/torch.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/CUDAHooks.h>
#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

template <typename T>
__global__ void CoolForward(const int nthreads, const T* input, T* output, 
                            const int batch_size, const int channels, 
                            const int height, const int width)
{
    int index = (blockIdx.x * blockDim.x) + threadIdx.x;
    if (index < nthreads){
        int w = index % width;
        int h = (index / width) % height;
        int c = (index / width / height) % channels;
        int n = index / width / height / channels;

        output[index] = input[index] + 1;
    }
}

template <typename T>
__global__ void CoolBackward(const int nthreads, const T* grad, 
                             const int batch_size, const int channels, 
                             const int height, const int width, T* grad_input)
{
    int index = (blockIdx.x * blockDim.x) + threadIdx.x;
    if (index < nthreads){

        int w = index % width;
        int h = (index / width) % height;
        int c = (index / width / height) % channels;
        int n = index / width / height / channels;

        // print the gradient values
        printf("index=%d, n=%d, c=%d, h=%d, w=%d, value=%f \n", index, n, c, h, w, grad[index]);

        // copy the gradient values
        grad_input[index] = grad[index];
    }
}

This is my cpp-extension code:

at::Tensor cool_forward_cuda(const at::Tensor &input)
{
    int n = input.size(0);
    int c = input.size(1);
    int h = input.size(2);
    int w = input.size(3);

    auto output_size = n * c * h * w;

    dim3 grid(1);
    dim3 block(512);

    at::Tensor output = at::zeros({n, c, h, w}, input.type());
    at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_ALL_TYPES(input.type(), "cool_forward", ([&] {
        CoolForward<scalar_t><<<grid, block, 0, stream>>>(
             output_size,
             input.data<scalar_t>(),
             output.data<scalar_t>(),
             n,
             c,
             h,
             w);
      }));

    THCudaCheck(cudaGetLastError());
    return output;
}

at::Tensor cool_backward_cuda(const at::Tensor &grad,
                              const int batch_size,
                              const int channels,
                              const int height,
                              const int width)
{

    std::cout << "gradient:\n " <<  grad << std::endl;  // this prints the gradients correctly
    at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.type());

    dim3 grid(1);
    dim3 block(10);   // block size is 10 to help with debugging
    at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(grad.type(), "cool_backward", ([&] {
        CoolBackward<scalar_t><<<grid, block, 0, stream>>>(
             grad.numel(),
             grad.data<scalar_t>(),
             batch_size,
             channels,
             height,
             width,
             grad_input.data<scalar_t>());
      }));
    THCudaCheck(cudaGetLastError());
    return grad_input;
}

And after a bunch of boilerplate to get the extension running, here is my python code:

import torch
from cool_test import _C
from torch.autograd import Function
from torch import nn


class CoolFunc(Function):
    @staticmethod
    def forward(ctx, x):
        y = _C.cool_forward(x)
        return y

    @staticmethod
    def backward(ctx, grad):
        dx = _C.cool_backward(grad,
                              grad.shape[0], grad.shape[1],
                              grad.shape[2], grad.shape[3])
        return dx


class Cool(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return CoolFunc.apply(x)


x = torch.rand(1, 1, 5, 5).cuda()
x.requires_grad = True
m = Cool()
print("input=\n", x)

y = m(x)
print("y=", y)
s = y.sum()
print("s=", s)
s.backward()

print("x grad\n", x.grad)

which gives me this output:

input=
 tensor([[[[0.8135, 0.5903, 0.4413, 0.4109, 0.9244],
          [0.1199, 0.0319, 0.9599, 0.0336, 0.3685],
          [0.7604, 0.5010, 0.7218, 0.2321, 0.1330],
          [0.0986, 0.7200, 0.3598, 0.8679, 0.2629],
          [0.1547, 0.2007, 0.0335, 0.4450, 0.9741]]]],
       device='cuda:0', requires_grad=True)
y= tensor([[[[1.8135, 1.5903, 1.4413, 1.4109, 1.9244],
          [1.1199, 1.0319, 1.9599, 1.0336, 1.3685],
          [1.7604, 1.5010, 1.7218, 1.2321, 1.1330],
          [1.0986, 1.7200, 1.3598, 1.8679, 1.2629],
          [1.1547, 1.2007, 1.0335, 1.4450, 1.9741]]]],
       device='cuda:0', grad_fn=<CoolFuncBackward>)
s= tensor(36.1598, device='cuda:0', grad_fn=<SumBackward0>)
gradient:
 (1,1,.,.) = 
  1  1  1  1  1
  1  1  1  1  1
  1  1  1  1  1
  1  1  1  1  1
  1  1  1  1  1
[ Variable[CUDAFloatType]{1,1,5,5} ]
x grad
index=0, n=0, c=0, h=0, w=0, value=1.000000 
index=1, n=0, c=0, h=0, w=1, value=0.000000 
index=2, n=0, c=0, h=0, w=2, value=0.000000 
index=3, n=0, c=0, h=0, w=3, value=0.000000 
index=4, n=0, c=0, h=0, w=4, value=0.000000 
index=5, n=0, c=0, h=1, w=0, value=0.000000 
index=6, n=0, c=0, h=1, w=1, value=0.000000 
index=7, n=0, c=0, h=1, w=2, value=0.000000 
index=8, n=0, c=0, h=1, w=3, value=0.000000 
index=9, n=0, c=0, h=1, w=4, value=0.000000 
 tensor([[[[1., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]], device='cuda:0')

As you can see, the last tensor printed should just be the gradients copied over, yet it is 1 in only a single place and zero everywhere else. Any help?

I am pretty much a noob at CUDA so this may just boil down to bad memory access, but if someone could elaborate on how I should be doing this, I would be really thankful.

@goldsborough any help here, please?

Hi,

I am not sure, but I think the problem is that you ignore the strides when you index your tensors in the kernel.
Does changing your CoolFunc this way help?

    @staticmethod
    def backward(ctx, grad):
        grad = grad.clone()
        dx = _C.cool_backward(grad,
                              grad.shape[0], grad.shape[1],
                              grad.shape[2], grad.shape[3])
        return dx

Huh, that seems to have worked!
Can you please explain a bit more why simply creating a duplicate of the tensor fixes this?

I printed out the stride values of the gradients before and after the clone and I get

  • before: (0, 0, 0, 0)
  • after: (25, 25, 5, 1)

Is this potentially a bug?

Hi,

That is expected. For efficiency, the backward of sum returns an expanded tensor.
That means that in practice there is a single memory element allocated for it, and the stride is 0 for every dimension, as you can see.
So the first element along the first dimension is at position 0*0 = 0, but the second element is at position 1*0 = 0 as well. The general formula is position_in_memory = sum over dimensions of element_index * stride.

In PyTorch the last dimension is the contiguous one, so after cloning you get a full tensor: the stride of the last dimension is 1, as expected, since you need to move by 1 element to reach the next value along that dimension. For the second-to-last dimension the stride is 5 (the size of the last dimension), because there are 5 elements between consecutive values along that dimension. For your 1x1x5x5 gradient this gives strides (25, 25, 5, 1), so element (0, 0, 2, 3), for example, sits at 0*25 + 0*25 + 2*5 + 3*1 = 13.

Note that PyTorch tensors also have a storage_offset, which means that element [0, 0, 0, 0] of your tensor might not be the first element of the underlying storage.

If you're not familiar with these things, it might be simpler to keep the clone; then you can assume you have a contiguous tensor in the CUDA code and don't need to pass storage_offset and strides as arguments.
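
For instance, a quick way to see this from the C++ side (just a debugging sketch, not part of the extension above) is to print the layout information of grad inside cool_backward_cuda:

// Debugging sketch: print the layout the kernel would need in order to index
// `grad` correctly; run it before and after the clone().
for (int64_t d = 0; d < grad.dim(); ++d) {
    std::cout << "size(" << d << ")=" << grad.size(d)
              << " stride(" << d << ")=" << grad.stride(d) << std::endl;
}
std::cout << "storage_offset=" << grad.storage_offset() << std::endl;
// For the expanded gradient coming out of sum's backward every stride is 0;
// after grad.clone() a 1x1x5x5 tensor reports strides 25, 25, 5, 1.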

I see what you mean. The backward of sum simply expands the tensor to the shape of the forward input, but still has only one memory location allocated.

I guess this means that when reading the gradient tensor in backward, I have to take the stride values into account? (i.e. all the reads point back to the single memory element?)

To index an ATen tensor properly, you need to take into account its size, stride, and storage offset (not sure about the exact name of the last one in ATen).
At the moment, your CUDA kernel only considers the size and assumes storage_offset = 0 and the strides of a contiguous tensor.

You can either adapt your kernel to take all of them into account, or add a .clone() before your kernel to make sure that your assumptions hold.
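
For the first option, here is a rough sketch of what a stride-aware kernel could look like (names are hypothetical; the wrapper would have to pass grad.stride(0) through grad.stride(3) as extra arguments):

// Sketch of a stride-aware backward kernel, not the kernel from the question.
// If I'm not mistaken, grad.data<scalar_t>() already points at element
// [0, 0, 0, 0] (the storage offset is folded into the pointer), so only the
// strides are passed here.
template <typename T>
__global__ void CoolBackwardStrided(const int nthreads, const T* grad,
                                    const int channels, const int height,
                                    const int width,
                                    const int stride_n, const int stride_c,
                                    const int stride_h, const int stride_w,
                                    T* grad_input)
{
    int index = (blockIdx.x * blockDim.x) + threadIdx.x;
    if (index < nthreads){
        int w = index % width;
        int h = (index / width) % height;
        int c = (index / width / height) % channels;
        int n = index / width / height / channels;

        // position_in_memory = n*stride_n + c*stride_c + h*stride_h + w*stride_w
        int offset = n * stride_n + c * stride_c + h * stride_h + w * stride_w;

        // grad_input is freshly allocated with at::zeros, so it is contiguous
        // and can still be indexed linearly.
        grad_input[index] = grad[offset];
    }
}

With the expanded gradient every stride is 0, so every thread reads grad[0] (the single shared element); with the cloned 1x1x5x5 gradient the strides are (25, 25, 5, 1) and the offset is just index again.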
