Passing Tensor pointers to CUDA code

Hi,

I have a CUDA application that I want to interface with PyTorch under the following conditions:

  1. I do not want to add PyTorch as a dependency of the C++ application, for multiple reasons (e.g., CUDA version conflicts).

  2. I want to send/receive GPU data between PyTorch in Python and this C++ code. I am thinking of passing tensor.data_ptr() (after making the tensor contiguous) to the C++ code, and also passing a pointer to a pre-allocated output array which the C++ code will fill. This basically follows the approach from: Constructing PyTorch's CUDA tensor from C++ with image data already on GPU

Does 2 seem like the best way to achieve 1? Are there any pitfalls I should take into consideration?
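
Concretely, I imagine the Python side looking roughly like the sketch below (my_cuda_func is a placeholder name for the eventual binding; the real signature would depend on the C++ side):

import torch

num_edges = 1000
edge_costs = torch.rand(num_edges, device='cuda').contiguous()  # input, made contiguous first
sol_out = torch.empty(num_edges, dtype=torch.int32, device='cuda')  # pre-allocated output buffer

# pass the raw device pointers (plain Python ints) across the boundary
my_cuda_func(edge_costs.data_ptr(), sol_out.data_ptr(), num_edges)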

Thanks,
Ahmed

Your approach sounds valid, and you could take a look at this tutorial to see how a custom CUDA extension can be written (and in particular how the CUDA kernel is implemented).


I have written the following code (for the reference of future visitors), and it seems to run fine.

m.def("my_func_with_gpu_pointers", [](const long i_ptr, const long j_ptr, const long edge_costs_ptr, const long sol_out_ptr, const int num_edges, const int gpuDeviceID) 
{
  cudaSetDevice(gpuDeviceID);
  const int* const i = reinterpret_cast<const int* const>(i_ptr);
  const int* const j = reinterpret_cast<const int* const>(j_ptr);
  const float* const edge_costs = reinterpret_cast<const float* const>(edge_costs_ptr);
  int* const sol_out_ptr_int = reinterpret_cast<int* const>(sol_out_ptr);
  thrust::device_vector<int> i_thrust(i, i + num_edges);
  thrust::device_vector<int> j_thrust(j, j + num_edges);
  thrust::device_vector<float> costs_thrust(edge_costs, edge_costs + num_edges);
  thrust::device_vector<int> sol;
  sol = my_func(std::move(i_thrust), std::move(j_thrust), std::move(costs_thrust)); 
  thrust::copy(sol.begin(), sol .end(), sol_out_ptr_int);
}
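
For completeness, the call from the Python side looks roughly like this (my_module is a placeholder for whatever the compiled extension is named; the tensors must be contiguous, live on the target device, and stay referenced while the C++ code uses their pointers). PyTorch launches kernels asynchronously, so I synchronize before handing the raw pointers over in case the C++ code runs on a different stream:

import torch
import my_module  # the compiled pybind11 extension (name is illustrative)

num_edges = 1000
device_id = 0
dev = f'cuda:{device_id}'
i = torch.randint(0, 100, (num_edges,), dtype=torch.int32, device=dev).contiguous()
j = torch.randint(0, 100, (num_edges,), dtype=torch.int32, device=dev).contiguous()
edge_costs = torch.rand(num_edges, device=dev).contiguous()
sol_out = torch.zeros(num_edges, dtype=torch.int32, device=dev)  # output buffer filled by the C++ code

torch.cuda.synchronize()  # make sure pending PyTorch kernels finished writing the inputs
my_module.my_func_with_gpu_pointers(i.data_ptr(), j.data_ptr(), edge_costs.data_ptr(),
                                    sol_out.data_ptr(), num_edges, device_id)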