CUDA Extension: Illegal memory access was encoutered

I have an Ubuntu 18.04 OS with miniconda3, python 3.7, CUDA 10.1, CuDNN 7.4 and GCC 7.4 installed. I also have compiled PyTorch 1.4 from source.

I wrote a PyTorch C++/CUDA extension code for a specific task that I had using the exact steps mentioned in the tutorial page. My extension looks like this:

// This is the .cpp file
#include <torch/extension.h>
#include <vector>

std::vector<torch::Tensor> zbuffertri_cuda_forward(
                                torch::Tensor s2d,
                                torch::Tensor tri,
                                torch::Tensor visible,
                                int img_size = 224);

// C++ interface

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> zbuffertri_forward(torch::Tensor s2d, torch::Tensor tri, torch::Tensor visible, int img_size = 224)
{
    CHECK_INPUT(s2d);
    CHECK_INPUT(tri);
    CHECK_INPUT(visible);

    return zbuffertri_cuda_forward(s2d, tri, visible, img_size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &zbuffertri_forward, "ZBufferTri Operation (CUDA)");}
// This is the .cu file

# include <torch/types.h>
# include <cuda.h>
# include <cuda_runtime.h>
# include <vector>

#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=false)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}

__global__ void convert_to_mask(float *zbuffer, int img_size)
{
    for(int i=blockIdx.x*blockDim.x+threadIdx.x; i<img_size*img_size; i+=blockDim.x*gridDim.x)
    {
        if(zbuffer[i] == -INFINITY)
        {
            zbuffer[i] = 0;
        }
        else
        {
            zbuffer[i] = 1;
        }
    }
}

/* Forward Function */
std::vector<torch::Tensor> zbuffertri_cuda_forward(torch::Tensor s2d, torch::Tensor tri, torch::Tensor visible, int img_size = 224)
{
    auto s2d_data = s2d.data<float>();
    gpuErrchk(cudaPeekAtLastError());
    gpuErrchk( cudaDeviceSynchronize() );

    const int tri_num = tri.size(1);
    const int vertex_num = s2d.size(1);

    auto out = torch::ones({img_size, img_size}, torch::device(s2d.device())) * (tri_num-1);
    auto zbuffer = torch::ones({img_size, img_size}, torch::device(s2d.device())) * (-INFINITY);;

    int N = img_size*img_size;
    const int threads = 256;
    const dim3 blocks((img_size + threads - 1) / threads, img_size);
    int blockSize = 256;
    int numBlocks = (N + blockSize -1) / (2*blockSize);
    std::cout<<numBlocks<<" "<<blockSize<<std::endl;

    AT_DISPATCH_FLOATING_TYPES(zbuffer.type(), "zbuffer_tri_dispatch", ([&] {
        convert_to_mask<scalar_t><<<blocks, threads>>>(zbuffer.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
    }));

    return {out, zbuffer};
}

It gets compiled correctly without any errors. I can also import the compiled library into my code successfully. However, when I run its forward function, it doesn’t work. (There are a few other function calls happening inside the .cu file that I have commented out for now because even the simplest code is not working.) When I use gdb or cuda-gdb to debug into the cuda code by adding a breakpoint at the cuda function, I see an error saying “s2d=<error reading variable: Cannot access memory at address 0x2>, … at zbuffertri_implementation.cu:190”

Single stepping until exit from function _Z23zbuffertri_cuda_forwardN2at6TensorES0_S0_i@plt,
which has no line number information.
zbuffertri_cuda_forward (s2d=<error reading variable: Cannot access memory at address 0x2>, tri=..., visible=..., img_size=32767) at zbuffertri_implementation.cu:190

However, I can run through the function completely without any errors. When I return to the python line where the C++ function was called and try to access the tensors returned from the call, I get the following error:

tri_map_2d, mask_i = zbuffertri.forward(vertex2d_i.contiguous(), self.tri.float(), visible_tri[i].contiguous(), output_size)
(gdb) n
(Pdb) p tri_map_2d
THCudaCheck FAIL file=../aten/src/THC/THCCachingHostAllocator.cpp line=278 error=700 : an illegal memory access was encountered

Been dealing with this issues for days now. Any help would be much appreciated.

Sorry for the late reply! The post was on my ToDo list for some time and I just got a bit of time for debugging.

I’m wondering, why the code compiles fine without throwing any errors, as these points might cause some trouble:

  • convert_to_mask is called with a template argument, while its definition isn’t templated
  • zbufferpacked_accessoris passed with as a 2-dim accessor, which fits the dimensions of the tensor. Howeverconvert_to_masktakes afloat *` as the input argument. I would assume this will throw an error.
  • zbuffer is indexed in a single dimension, although a 2-dim accessor is passed and the tensor wasn’t flattened to a 1-dim tensor

Given that and removing some unused code, this should work:

from torch.utils import cpp_extension

cuda_source = """
template <typename scalar_t>
__global__ void convert_to_mask(
  torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> zbuffer,
  int img_size) {
    int row = blockIdx.y;
    int index = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    
    for(int i = index; i < img_size; i += stride) {
        if(zbuffer[row][i] == -INFINITY) {
            zbuffer[row][i] = 0;
        } else {
            zbuffer[row][i] = 1;
        }
    }
}

std::vector<torch::Tensor> zbuffertri_cuda_forward(torch::Tensor s2d, int img_size = 224)
{
    auto out = torch::ones({img_size, img_size}, torch::device(s2d.device())) * 10.;
    auto zbuffer = torch::ones({img_size, img_size}, torch::device(s2d.device())) * (-INFINITY);

    const int threads = 256;
    const dim3 blocks((img_size + threads - 1) / threads, img_size);

    std::cout << "zbuffer size " << zbuffer.sizes();

    AT_DISPATCH_FLOATING_TYPES(zbuffer.type(), "zbuffer_tri_dispatch", ([&] {
        convert_to_mask<scalar_t><<<blocks, threads>>>(
          zbuffer.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
          img_size);
    }));

    return {out, zbuffer};
}
"""

cpp_source = """
    std::vector<torch::Tensor> zbuffertri_cuda_forward(torch::Tensor s2d, int img_size);
"""

module = torch.utils.cpp_extension.load_inline(
    name="cuda_test_extension",
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions="zbuffertri_cuda_forward",
    verbose=True,
)

s2d = torch.randn(1).cuda()
img_size = 100
out, zb = module.zbuffertri_cuda_forward(s2d, img_size)

print(out)
print(zb)

Let me know, if that fits your use case.