Error converting external data on GPU to Tensor in C++

I am attempting to build a C++ extension that returns a PyTorch tensor to Python, wrapping GPU memory that has already been allocated in C++, as below.

#include <torch/torch.h>
#include <cuda.h>
#include <cuda_runtime.h>

#define CU_RETURN_NOT_OK(STMT)                                                  \
  do {                                                                          \
    CUresult ret = (STMT);                                                      \
    if (ret != CUDA_SUCCESS) {                                                  \
      cout << "Cuda Driver API call in " << __FILE__ << " at line " << __LINE__ \
           << " failed with code " << ret << ": " << #STMT << endl;             \
    }                                                                           \
  } while (0)

using namespace at;
using namespace std;

Tensor load() {
  CUdevice handle;
  CU_RETURN_NOT_OK(cuDeviceGet(&handle, 0));
  CUcontext context_;
  CU_RETURN_NOT_OK(cuCtxCreate(&context_, 0, handle));
  CUdeviceptr data_done;
  CU_RETURN_NOT_OK(cuMemAlloc(&data_done, 16*sizeof(float)));
  auto f = torch::CUDA(kFloat).tensorFromBlob(reinterpret_cast<void*>(data_done), {16});
  cout << f << endl;
  cuCtxDestroy(context_);
  return f;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("load", &load, "load data");
}

However, when I attempt to run load, I get the following error:

In [3]: f = pyt.load()
THCudaCheck FAIL file=/home/devbox/projects/pytorch/aten/src/THC/generic/THCStorage.c line=150 error=49 : incompatible driver context
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-ae3941c99de0> in <module>()
----> 1 f = pyt.load()

RuntimeError: cuda runtime error (49) : incompatible driver context at /home/devbox/projects/pytorch/aten/src/THC/generic/THCStorage.c:150

The GPU memory does appear to have been allocated successfully, so the problem seems to be only in tensorFromBlob.

tensorFromBlob does appear to work if I pass it the data pointer of an already existing CUDA tensor. The code below returns a CUDA tensor in Python:

Tensor load() {
  auto g = torch::CUDA(kFloat).rand({16});
  auto f = torch::CUDA(kFloat).tensorFromBlob(reinterpret_cast<void*>(g.data_ptr()), {16});
  return f;
}

As a sanity check: does tensorFromBlob take a CUDA device pointer directly, or does the pointer need to be wrapped in some way? I couldn't find in the documentation what type is required for the CUDA case.

If it does take the pointer directly, what could be causing this error?
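To make the ownership side of the question concrete, here is a minimal host-only sketch of how I understand a tensorFromBlob-style wrapper to behave: it keeps the raw pointer exactly as passed (no wrapping) and never frees it unless it is given a deleter, in which case the deleter runs once, when the last view of the buffer goes away. BlobView is a hypothetical class for illustration, not ATen's implementation, and the optional deleter argument mirrors what I believe tensorFromBlob offers, which I have not verified for this build. It uses ordinary host memory so it runs without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>

// Hypothetical stand-in (not ATen code) for a tensorFromBlob-style wrapper.
// It stores the caller's raw pointer unchanged and, like a non-owning view,
// never frees it unless a deleter is supplied.
class BlobView {
 public:
  using Deleter = std::function<void(void*)>;

  // The default deleter is a no-op: the caller keeps ownership of `data`.
  explicit BlobView(void* data, std::size_t numel,
                    Deleter deleter = [](void*) {})
      : owner_(data, std::move(deleter)), numel_(numel) {}

  void* data() const { return owner_.get(); }
  std::size_t numel() const { return numel_; }
  // Number of views currently sharing this buffer.
  long use_count() const { return owner_.use_count(); }

 private:
  // Copies of a BlobView share one control block, so the deleter runs
  // exactly once, when the last copy is destroyed.
  std::shared_ptr<void> owner_;
  std::size_t numel_;
};
```

If that model holds, the driver pointer in the first snippet would be used directly, but it has to stay valid for as long as the tensor lives. Separately from error 49, calling cuCtxDestroy(context_) before return f would, as far as I understand the driver API, release the allocations made in that context (including data_done) while f still points at them.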

System details, in case they are relevant: PyTorch is built from source from the latest master, the CUDA version is 9.1.85, and the driver version for the GTX 970 is 390.30.
