Custom CUDA Kernel Returns Zero or Illegal Memory Access Error

Hi everyone,
Following the C++/CUDA extension tutorial on the PyTorch website and having a look at the linked source code, I have created my own CUDA kernel. It does not do anything useful; it is just a learning project.
My issue is that in this simple example I either get back the zero matrix I created as the result in Python, or, if cudaDeviceSynchronize(); is added after the kernel call, an illegal memory access error (code 700).

The kernel code is the following; it copies the data from a const 3D input matrix into a 3D output matrix:

template <typename scalar_t>
__global__ void reduce_cuda_kernel(
    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> matrix,
    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> output
)
{
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    
    if ((y < output.size(0)) && (x < output.size(1)))
    {
        printf("Hello from CUDA If.\n");
        for(int i=0;i<matrix.size(2);i++)
        {
            output[y][x][i] = matrix[y][x][i];
        }
    }
}

The kernel is launched by the following code, which creates a zero tensor shaped like the input matrix “matrix” with zeros_like and launches the kernel with a calculated number of blocks, each with a fixed 32x32 thread grid:

torch::Tensor reduce_cuda(torch::Tensor matrix)
{
    // Define output like our input matrix
    auto output = torch::zeros_like(matrix);
    // Fixed block size of 32x32 threads
    const dim3 threads(32, 32, 1);
    // Calculate grid size based on matrix
    const auto H = matrix.size(0);
    const auto W = matrix.size(1);
    const dim3 blocks(ceilf(H/32.0), ceilf(W/32.0), 1);
    
    // Call the kernel
    AT_DISPATCH_FLOATING_TYPES(output.scalar_type(), "reduce_cuda", ([&] {
            reduce_cuda_kernel<scalar_t><<<blocks, threads>>>(
            matrix.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
            output.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
    }));
    cudaDeviceSynchronize();
    printf("Cuda Error: %d \n", cudaGetLastError());
    return output;
}
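
For completeness, the Python-facing entry point follows the tutorial's pybind11 pattern and looks roughly like this (a sketch only; the function name "reduce" here is just illustrative, the actual files are in the Gist linked below):

// reduce.cpp -- C++ entry point bound to Python (sketch of the tutorial's pattern)
#include <torch/extension.h>

// implemented in reduce_cuda_kernel.cu
torch::Tensor reduce_cuda(torch::Tensor matrix);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("reduce", &reduce_cuda, "Copies a 3D matrix on the GPU (CUDA)");
}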

The entire project and the setup.py install output are available in this GitHub Gist.

As described above, if I remove cudaDeviceSynchronize(); from reduce_cuda(…) in reduce_cuda_kernel.cu, I get back in Python the zero matrix created with

auto output = torch::zeros_like(matrix);

as the result of the kernel call. The returned matrix has the same shape and dtype and is on the same CUDA device, but it contains only zeros. No values are copied into the output, yet CUDA does not report an error (error code 0 is printed) and I get all the “Hello from CUDA If” printouts.
If I add cudaDeviceSynchronize();, error 700 (illegal memory access) is thrown when I try to print the result in Python.
I cannot find the error myself, as other code, e.g. the tutorial example mentioned above, does more or less the same thing when dispatching the kernel.
Any help is appreciated, thanks in advance.

Greetings,
Jan

I believe that the true behavior is what happens with the device synchronize: I suspect the same failure would surface if you tried to, e.g., print the output tensor, as that would also synchronize to copy the values back to the host. Another check you could do is run your executable with compute-sanitizer, e.g. compute-sanitizer ./myexec (since the kernel is called from a Python extension, wrapping the interpreter that runs your test script should also work), which could hopefully give you a more pinpointed location of where the IMA happens.
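
Kernel launches are asynchronous, so errors that occur while the kernel runs only surface at the next synchronizing call. As a rough sketch (not specific to your project), this is the pattern I'd drop into reduce_cuda right after the dispatch to separate launch errors from execution errors:

    // Launch-configuration errors (bad grid/block dimensions, etc.) are reported immediately
    cudaError_t launch_err = cudaGetLastError();
    printf("Launch error: %s\n", cudaGetErrorString(launch_err));

    // Errors that occur while the kernel executes (e.g. an illegal memory access)
    // only show up at a synchronizing call like this one, or later when the
    // result is copied back to the host
    cudaError_t sync_err = cudaDeviceSynchronize();
    printf("Execution error: %s\n", cudaGetErrorString(sync_err));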

To start, I’d try a few sanity checks like removing the memory accesses to either matrix, output, or both, and consider passing the tensors directly (e.g. as raw data pointers) rather than as packed accessors, to see if any of those changes the behavior. It might also help to see whether the failure only occurs beyond a specific index, or if even accessing the very start of each tensor triggers the IMA; a rough sketch of that check follows below.
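
For example, a minimal variant (names and assumptions are mine, not from your project) that takes raw pointers obtained from data_ptr<scalar_t>() and touches only the very first element of each tensor; if even this triggers the IMA, the problem is likely with the tensors or the launch rather than with the accessor indexing:

template <typename scalar_t>
__global__ void touch_first_element_kernel(
    const scalar_t* __restrict__ matrix,  // matrix.data_ptr<scalar_t>()
    scalar_t* __restrict__ output)        // output.data_ptr<scalar_t>()
{
    // A single thread copies only element [0][0][0]
    if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0)
    {
        output[0] = matrix[0];
    }
}

It can be launched with the same <<<blocks, threads>>> configuration from inside the same AT_DISPATCH_FLOATING_TYPES block, passing matrix.data_ptr<scalar_t>() and output.data_ptr<scalar_t>(). If that works, you can reintroduce the full indexing step by step to narrow down where it goes wrong.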