Constructing PyTorch's CUDA tensor from C++ with image data already on GPU

Hi!

I need to use a PyTorch model from an existing C++/OpenCV-based application. All images are processed with OpenCV’s CUDA modules. Currently, I have to copy all the data back to the CPU and use boost::python converters to make a NumPy array from it, which I can then use to construct a PyTorch Tensor object. It works, but it introduces major slowdowns (of course).

Is there a way to prepare data for the PyTorch backend so that all of it stays on the GPU the whole time?

I know that PyTorch shares its backend with the original Torch (to some extent), so is this project of any use?

Looking forward to any clues.

You can pass the pointer of an existing PyTorch torch.cuda.FloatTensor to the C++ side and issue a cudaMemcpyAsync from the OpenCV buffer to the torch.cuda.FloatTensor buffer.
The pointer of the Tensor can be obtained with:

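import torch

a = torch.cuda.FloatTensor(3, 640, 480)  # new, uninitialized CUDA buffer of 3x640x480 floats
pointer = a.data_ptr()  # address of the first element in device memory, as a Python int
# pass this pointer into the C/C++ side of your OpenCV pipeline
# copy the OpenCV buffer into the tensor's CUDA memory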
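One detail worth checking on the copy side: an OpenCV GpuMat is usually row-padded (its `step` can be larger than one row of pixels), while the tensor buffer is densely packed, so a single flat copy may not line up. The sketch below only illustrates that idea on host memory with `ctypes`, so it runs without a GPU. The helper `copy_pitched_rows` and the sizes are made up for illustration; on the device, the analogous single call would presumably be `cudaMemcpy2DAsync` with the GpuMat `step` as the source pitch, the packed row size as the destination pitch, and `cudaMemcpyDeviceToDevice` as the kind.

```python
import ctypes


def copy_pitched_rows(dst_ptr, src_ptr, row_bytes, src_pitch, rows):
    """Copy `rows` rows of `row_bytes` bytes each from a row-padded source
    into a densely packed destination, given raw integer addresses.

    Hypothetical host-side stand-in for a pitched device-to-device copy.
    """
    if src_pitch < row_bytes:
        raise ValueError("source pitch is smaller than one row of data")
    for r in range(rows):
        ctypes.memmove(dst_ptr + r * row_bytes, src_ptr + r * src_pitch, row_bytes)


# Made-up sizes: 4 rows of 6 floats, each source row padded to 32 bytes.
rows, cols = 4, 6
row_bytes = cols * ctypes.sizeof(ctypes.c_float)  # 24 bytes of real data per row
src_pitch = 32                                    # bytes between row starts in the source

# Source buffer with padding at the end of every row (plays the role of the GpuMat).
src = (ctypes.c_ubyte * (src_pitch * rows))()
for r in range(rows):
    row = (ctypes.c_float * cols).from_buffer(src, r * src_pitch)
    for c in range(cols):
        row[c] = r * cols + c

# Packed destination (plays the role of the tensor buffer); its address is
# handed over as a plain integer, the same way a.data_ptr() would be.
dst = (ctypes.c_float * (rows * cols))()
copy_pitched_rows(ctypes.addressof(dst), ctypes.addressof(src), row_bytes, src_pitch, rows)

result = list(dst)
```

After the copy, `result` holds the 24 values in row-major order with the padding dropped. The tensor that owns the destination buffer has to stay alive on the Python side for as long as the C++ side writes to it.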

Lastly, when using CUDA with datasets (i.e. multiprocessing with CUDA), you have to use Python 3 and the spawn or forkserver start method: http://pytorch.org/docs/notes/multiprocessing.html#sharing-cuda-tensors
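For reference, the linked notes describe sharing CUDA tensors between processes as supported only with the `spawn` or `forkserver` start methods. A minimal sketch of selecting one of them with the standard `multiprocessing` module follows; `pick_start_method` is a hypothetical helper, and the assumption is that `torch.multiprocessing`, which wraps the standard module, can be used the same way.

```python
import multiprocessing as mp


def pick_start_method():
    """Return a start method suitable for sharing CUDA tensors:
    prefer "spawn", otherwise fall back to "forkserver"."""
    available = mp.get_all_start_methods()
    for method in ("spawn", "forkserver"):
        if method in available:
            return method
    raise RuntimeError("neither spawn nor forkserver is available")


method = pick_start_method()
# A context keeps the choice local instead of changing the global default.
ctx = mp.get_context(method)
# Worker processes would then be created through ctx.Process(...) or ctx.Pool(...).
```

Using a context object avoids calling `set_start_method` globally, which can only be done once per program.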


Whoa. Somehow I’ve overlooked the data_ptr method in the docs.

I just wrote a basic example illustrating your approach.

Thank you!
