CUDA Synchronize using ATen.h for a PyTorch CUDA C++ extension

Hi all,

I am trying to implement a mixed C++/CUDA Torch extension. My understanding is that I can only use ATen.h for this (e.g. torch.h cannot be used). The issue is that I want to synchronize CUDA kernels from the CPU, i.e. do the equivalent of cudaDeviceSynchronize() in CUDA, but I cannot find any similar function in the at::cuda namespace. Could anyone help me out, please? Am I missing something?

I think you could try the following:

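#include <ATen/cuda/CUDAContext.h>  // declares at::cuda::getCurrentCUDAStream and AT_CUDA_CHECK

at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaStreamSynchronize(stream));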

Thanks, ptrblck! I'll try it and post my findings.

I am now using CUDAStream to synchronize my CUDA kernels from the CPU. Since this is a CUDA C++ PyTorch extension, I import the module in Python and then call my extension. I get this error at the line where the synchronization is implemented: AT_CUDA_CHECK(cudaStreamSynchronize(stream));

RuntimeError: CUDA error: CUDA driver version is insufficient for CUDA runtime version…

Any advice would be highly appreciated.

Do you have multiple CUDA versions installed as described here?

I think the issue was that the PyTorch version I had installed was compiled for a different CUDA version than the one on my system. A clean PyTorch installation with the matching CUDA version fixed the original problem. Thanks a lot for that!
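The "driver version is insufficient" error means the installed NVIDIA driver supports an older CUDA version than the runtime PyTorch was built against. Both numbers can be read at run time with cudaDriverGetVersion(int*) and cudaRuntimeGetVersion(int*), which encode a version as 1000 * major + 10 * minor (CUDA 9.2 is 9020). Below is a minimal sketch of decoding and comparing the two values; the helper names are hypothetical, and the CUDA calls themselves are left out so it builds without the toolkit. The comparison is deliberately simple and ignores the minor-version compatibility introduced with CUDA 11.

```cpp
#include <cassert>
#include <string>

// A CUDA version as reported by cudaDriverGetVersion / cudaRuntimeGetVersion,
// split into its parts.
struct CudaVersion {
  int major;
  int minor;
};

// Decodes the 1000 * major + 10 * minor encoding (e.g. 9020 -> 9.2).
inline CudaVersion decode_cuda_version(int encoded) {
  assert(encoded >= 0);
  return {encoded / 1000, (encoded % 1000) / 10};
}

inline std::string version_string(CudaVersion v) {
  return std::to_string(v.major) + "." + std::to_string(v.minor);
}

// Hypothetical check: the driver must support at least the runtime's version.
// Simplified on purpose: ignores CUDA 11+ minor-version compatibility.
inline bool driver_supports_runtime(int driver_version, int runtime_version) {
  return driver_version >= runtime_version;
}
```

In an extension, the two integers would come from cudaDriverGetVersion(&driver) and cudaRuntimeGetVersion(&runtime); a driver reporting 8000 (CUDA 8.0) fails the check against a runtime reporting 10000 (CUDA 10.0).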

Now I am facing another issue.

When running my extension from Python, the following error is thrown:

RuntimeError: CUDA error: an illegal memory access was encountered (bspline_cuda at bspline_cuda_kernel.cu:252)

Line 252 is this check:

AT_CUDA_CHECK(cudaStreamSynchronize(stream));

The code segment that is giving me issues is this:

vector<at::Tensor> bspline_cuda(at::Tensor x, at::Tensor y)
{

    const int n= NCP;
    const int m_mm = NCP;
    const int n_mm = NCP;
    const int k_mm=1;

    dim3 dimGrid((k_mm + BLOCK_SIZE - 1) / BLOCK_SIZE, (m_mm + BLOCK_SIZE - 1) / BLOCK_SIZE);
    dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);

    auto a  = at::zeros_like(x);
    auto b  = at::zeros_like(x);
    auto c  = at::zeros_like(x);
    auto d  = at::zeros_like(x);
    
    auto delta  = at::zeros(NCP-1,  at::kFloat);
    auto Delta  = at::zeros(NCP-1,  at::kFloat);
    auto M      = at::zeros({NCP,NCP},  at::kFloat);
    auto D      = at::zeros(NCP,  at::kFloat);

    auto TensorM_inv = at::zeros({NCP,NCP});

    for (int j =0; j<n; j++)
        for (int i=0;i<n; i++)
            TensorM_inv[i][j] = M_inv[i][j];

    AT_DISPATCH_FLOATING_TYPES(a.type(), "NaturalCubicSplinePart1", ([&] {
        NaturalCubicSplinePart1<scalar_t><<<N/TPB, TPB>>>(
            a.data<scalar_t>(),
            delta.data<scalar_t>(),
            Delta.data<scalar_t>(),
            x.data<scalar_t>(),
            y.data<scalar_t>());
    }));

    at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
    AT_CUDA_CHECK(cudaStreamSynchronize(stream));

Thanks a lot in advance for the help.
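A related detail in the snippet above, separate from the synchronization question: dimGrid is computed with ceil division, but the NaturalCubicSplinePart1 launch uses N/TPB blocks, which rounds down. If N is not a multiple of TPB, the last N % TPB elements get no thread, and if N < TPB the grid has zero blocks, so the launch fails with an invalid-configuration error that only shows up if it is checked (e.g. with cudaGetLastError()). Rounding up instead leaves some threads past the end, so the kernel then needs a bounds check such as `if (i >= N) return;`. N and TPB are defined elsewhere in the poster's code; blocks_for below is a hypothetical helper name, sketched in plain C++.

```cpp
#include <cassert>

// Ceil division: how many blocks of `threads_per_block` threads are needed
// to cover `n` elements. Same formula as the dimGrid computation above.
inline int blocks_for(int n, int threads_per_block) {
  assert(n >= 0 && threads_per_block > 0);
  return (n + threads_per_block - 1) / threads_per_block;
}
```

With this helper the launch would read `NaturalCubicSplinePart1<scalar_t><<<blocks_for(N, TPB), TPB>>>(...)`, paired with the bounds check inside the kernel.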

Hi all, I finally found what was wrong with my code:

a) I had a very old NVIDIA driver and CUDA 8.0, so I updated my system to a clean installation:

NVIDIA-SMI 410.104 Driver Version: 410.104 CUDA Version: 10.0

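b) AT_CUDA_CHECK(cudaStreamSynchronize(stream)); was throwing an error because I was not careful with the types of my tensors. To solve this problem I made sure all types matched. In ATen the type also includes the device: at::zeros(NCP-1, at::kFloat) creates a CPU tensor, while x.type() carries both the device (CUDA) and the scalar type of x, so the helper tensors are now allocated on the GPU and the kernel no longer receives pointers to host memory. Below is the section of my code with the types correctly set.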

vector<at::Tensor> bspline_cuda(at::Tensor x, at::Tensor y)
{

    const int n= NCP;
    const int m_mm = NCP;
    const int n_mm = NCP;
    const int k_mm=1;

    dim3 dimGrid((k_mm + BLOCK_SIZE - 1) / BLOCK_SIZE, (m_mm + BLOCK_SIZE - 1) / BLOCK_SIZE);
    dim3 dimBlock(BLOCK_SIZE, BLOCK_SIZE);

    auto a  = at::zeros_like(x);
    auto b  = at::zeros_like(x);
    auto c  = at::zeros_like(x);
    auto d  = at::zeros_like(x);
    
    auto delta  = at::zeros(NCP-1,  x.type());
    auto Delta  = at::zeros(NCP-1,  x.type());
    auto M      = at::zeros({NCP,NCP},  x.type());
    auto D      = at::zeros(NCP,  x.type());

    auto TensorM_inv = at::zeros({NCP,NCP}, x.type());

    for (int j =0; j<n; j++)
        for (int i=0;i<n; i++)
            TensorM_inv[i][j] = M_inv[i][j];


    AT_DISPATCH_ALL_TYPES(x.type(), "NaturalCubicSplinePart1", ([&] {
        NaturalCubicSplinePart1<scalar_t><<<N/TPB, TPB>>>(
            a.data<scalar_t>(),
            delta.data<scalar_t>(),
            Delta.data<scalar_t>(),
            x.data<scalar_t>(),
            y.data<scalar_t>());
    }));

    at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
    AT_CUDA_CHECK(cudaStreamSynchronize(stream));

Thanks a lot, my issue is fixed. :slight_smile:


Hello, how do I use these two lines, and where should they be placed?

You can call these two lines of code in your C++ code when you want to synchronize the stream.


I added it this way, but it doesn't work: it takes the same time as without it.

(Related topic: "Transferring tensors from CUDA to CPU takes a lot of time")

Could you post a minimal code snippet which shows the slow data transfer?
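On the observation above that the synchronization call does not change the measured time: cudaStreamSynchronize does not make a kernel run faster; it only makes the host wait until the work queued on the stream has finished. Kernel launches return to the host immediately, so a timer stopped without synchronizing measures only the launch, while a timer stopped after synchronizing includes the kernel's run time. Likewise, a device-to-host copy such as .cpu() waits for pending kernels, so time that looks like transfer time is often kernel time; whether that explains the linked topic depends on code not shown here. The toy model below sketches this in plain C++; ToyStream and time_launch_ns are hypothetical names, not CUDA or ATen APIs, and unlike a GPU the queued work runs inside synchronize() rather than concurrently.

```cpp
#include <cassert>
#include <chrono>
#include <functional>
#include <queue>
#include <thread>
#include <utility>

// Toy, single-threaded model of a CUDA stream (hypothetical, not a CUDA API).
// launch() only queues work and returns at once, like kernel<<<...>>>();
// synchronize() returns only after all queued work is done, like
// cudaStreamSynchronize(stream).
class ToyStream {
 public:
  void launch(std::function<void()> work) { pending_.push(std::move(work)); }

  void synchronize() {
    while (!pending_.empty()) {
      pending_.front()();
      pending_.pop();
    }
  }

 private:
  std::queue<std::function<void()>> pending_;
};

// Wall-clock time (ns) of launching `work` on `stream`, optionally waiting for it.
inline long long time_launch_ns(ToyStream& stream, std::function<void()> work, bool sync) {
  const auto start = std::chrono::steady_clock::now();
  stream.launch(std::move(work));
  if (sync) stream.synchronize();
  const auto stop = std::chrono::steady_clock::now();
  return std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start).count();
}
```

With a "kernel" that sleeps for 20 ms, the unsynchronized measurement returns before the work has run at all, while the synchronized one is at least 20 ms; the total work is the same in both cases.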

Hi Jorge,
I don't know how to install the C++ extension's header files. Could you tell me how? Thanks a lot :slight_smile:

I am having a hard time compiling these lines.

I have the following in my code:

at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaStreamSynchronize(stream));

The compiler could not find at::cuda or AT_CUDA_CHECK. So I tried adding the include

#include <ATen/cuda/CUDAContext.h>

But the compiler then complained:

fatal error: cublas_v2.h: No such file or directory


Does anyone know what I am missing?

Thanks in advance.

In my case, the following snippet worked for me.

#include <torch/extension.h>

#include <chrono>
#include <cstdio>   // fprintf
#include <cstdlib>  // exit
#include <cuda.h>
#include <cuda_runtime.h>

// From https://stackoverflow.com/a/14038590/2313889
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess) 
   {
      fprintf(stderr,"GPU Assert Error: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}

std::vector<at::Tensor> multiplication_complex_cuda(at::Tensor x, at::Tensor h) {

    const int THREADS = 1024;
    const int B = x.size(0);
    const int F = h.size(0);
    const int C = x.size(1);
    const int H = x.size(2);
    const int W = x.size(3);
    const int PLANE_SIZE = H*W;

    const auto Z = (H*W + THREADS - 1)/THREADS;
    const dim3 GRID_SIZE(B, F, Z);

    // x.options() keeps both the device and the dtype of x, so out's scalar
    // type matches the scalar_t used for its accessor below.
    auto out = torch::zeros({B, F, H, W, 2}, x.options());

    auto start = std::chrono::high_resolution_clock::now();

    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "multiplication_complex_cuda",
    ([&] {
        multiplication_traditional_cuda_kernel<scalar_t><<<GRID_SIZE, THREADS>>>(
        x.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
        h.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
        out.packed_accessor32<scalar_t, 5, torch::RestrictPtrTraits>(),
        THREADS, C, W, PLANE_SIZE);
    }));

    CUDA_CHECK(cudaGetLastError());       // reports launch errors (e.g. a bad configuration)
    CUDA_CHECK(cudaDeviceSynchronize());  // waits for the kernel, so the timing includes it

    auto stop = std::chrono::high_resolution_clock::now();
    // int64 holds nanosecond counts exactly; float32 is only exact up to 2^24 ns (about 16.8 ms).
    auto duration = at::zeros({1}, at::kLong);
    duration[0] = std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start).count();

    return {out, duration};
}


Although I got around the problem, I would still appreciate it if someone could point out what was wrong before. As you can see, I avoided the ATen functions for handling streams; it would be handy to know how to use them.

Thank you all.
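One practical difference between the gpuAssert helper in the snippet above and AT_CUDA_CHECK: gpuAssert calls exit() on failure, which terminates the whole Python process, whereas AT_CUDA_CHECK throws a C++ exception that the extension bindings turn into a Python RuntimeError (the "RuntimeError: CUDA error: ..." messages earlier in this thread). A throwing check can be sketched in plain C++ as below. Status, status_string, check_status and CHECK_STATUS are hypothetical stand-ins so the sketch builds without the CUDA toolkit; in a .cu file the same shape would use cudaError_t, cudaSuccess and cudaGetErrorString.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for cudaError_t.
enum class Status { Success, IllegalAddress };

// Hypothetical stand-in for cudaGetErrorString.
inline const char* status_string(Status s) {
  switch (s) {
    case Status::Success: return "no error";
    case Status::IllegalAddress: return "an illegal memory access was encountered";
  }
  return "unknown error";
}

// Throws instead of calling exit(), so the caller (and, through the Python
// bindings, the Python code) can see and handle the failure.
inline void check_status(Status code, const char* file, int line) {
  assert(file != nullptr);
  if (code != Status::Success) {
    std::ostringstream msg;
    msg << "CUDA error: " << status_string(code) << " (" << file << ":" << line << ")";
    throw std::runtime_error(msg.str());
  }
}

// Records the call site, like gpuAssert's __FILE__/__LINE__ arguments.
#define CHECK_STATUS(expr) check_status((expr), __FILE__, __LINE__)
```

Because pybind11 maps std::runtime_error to RuntimeError, a check written this way leaves the Python session alive and lets the failure be caught with try/except.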