C++/CUDA extension with multiple GPUs

I followed the tutorial on creating custom C++/CUDA extensions: https://pytorch.org/tutorials/advanced/cpp_extension.html.

What I have works locally (only one PyTorch-capable GPU), but I run into problems on our cluster with 4 GPUs per node:

  • When I start training, nvidia-smi shows four Python processes using GPU 0, while each of the other GPUs has a single Python process. Before I implemented the extension, I saw only one process per GPU.
  • After some time I get “RuntimeError: CUDA error: an illegal memory access was encountered” on GPUs 1, 2 and 3 (but not on 0).
  • I started the Python script with CUDA_LAUNCH_BLOCKING=1 and could trace the error to my own CUDA extension (see the sketch after this list).
  • I have the feeling that wrong data is being used, because the system is not learning. I have no proof of that, though; it might be an unrelated issue.
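
As a debugging aid, an explicit check after each kernel launch localizes the error the same way CUDA_LAUNCH_BLOCKING=1 does; the following is just a sketch, not code from the tutorial:

    #include <cuda_runtime.h>
    #include <cstdio>

    // Debugging sketch: synchronize, then surface any pending CUDA error at
    // the call site, similar to what CUDA_LAUNCH_BLOCKING=1 gives globally.
    #define CUDA_CHECK_LAST()                                           \
      do {                                                              \
        cudaDeviceSynchronize();                                        \
        cudaError_t err = cudaGetLastError();                           \
        if (err != cudaSuccess)                                         \
          std::printf("CUDA error: %s\n", cudaGetErrorString(err));     \
      } while (0)

    // usage after a kernel launch:
    //   my_kernel<<<blocks, threads>>>(...);
    //   CUDA_CHECK_LAST();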

I set the CUDA device by putting all PyTorch tensors onto the right device using device names like “cuda:2”.

The extension is a computation that includes a gradient computation for the backward pass. I have to create some new matrices in the C++ part, which I do like this:

    // Allocate with the same dtype and on the same device as the input `xes`.
    torch::Tensor sum = torch::zeros({n.batch, n.layers, n.xes},
                                     torch::dtype(xes.dtype()).device(xes.device()));

I hope that this puts them onto the right GPU, and I think I create all tensors with such code (the code is short; there is one such allocation in the forward pass and two in the backward pass).
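
As an aside, a slightly shorter equivalent is to inherit dtype and device from an existing tensor via its options (a sketch with the same `n` and `xes` as above):

    // xes.options() carries the dtype, device and layout of `xes`, so the
    // new tensor is allocated on whichever GPU the inputs live on.
    torch::Tensor sum = torch::zeros({n.batch, n.layers, n.xes}, xes.options());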

My interpretation of the issue is that the kernel runs on GPU 0 anyway, and that some synchronisation issue or similar follows from that.

I would like to use cudaSetDevice from the CUDA runtime API, but I wouldn’t know where to get the correct GPU id from PyTorch; a sketch of what I have in mind follows below. Besides, there might be a nicer way to do it. Does anything jump out at you? If not, I’ll create a minimal example.
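
Something like this (untested sketch; reading the device index off one of the input tensors):

    #include <cuda_runtime.h>
    #include <torch/extension.h>

    // Untested sketch: make the GPU that `input` lives on the current device
    // before launching kernels. Note this does not restore the previous
    // device afterwards.
    void select_device_of(const torch::Tensor& input) {
      if (input.is_cuda()) {
        cudaSetDevice(input.device().index());  // e.g. 2 for "cuda:2"
      }
    }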

Thanks, Adam

I guess that the custom kernel tries to access data in a tensor from another GPU and could thus fail.
A minimal (executable) code snippet would be great for debugging.

Thanks for the answer!

I’m pretty sure that all the tensors going into the C++ functions are on the same GPU. The data going in looks something like this:

    a: Tensor = ...  # a tensor that already lives on GPU x (e.g. "cuda:2")
    b = a.where(a[:, :, :, 0] > 0, torch.zeros(1, device=a.device))
    result = cpp_extension_function(a, b)
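
To double-check this assumption inside the extension, a cheap assertion could go at the top of the C++ function; just a sketch using TORCH_CHECK from torch/extension.h:

    // Sketch: assert both inputs share one device before any kernel launch.
    TORCH_CHECK(a.device() == b.device(),
                "expected a and b on the same device, got ",
                a.device(), " and ", b.device());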

I’m also not moving data between GPUs. I start each Python script with the GPU it should use, and that’s it.

Now I had another idea: instead of putting the tensors on the right GPU using PyTorch’s cuda:x device names, I switched to NVIDIA’s environment variable CUDA_VISIBLE_DEVICES=x. That works as expected, presumably because the selected GPU then appears as device 0 inside the process, so the extension’s default device matches the tensors’ device. But I’d prefer the C++/CUDA functions to behave correctly even when the visible devices are not limited.

If I really don’t have to do anything special to launch the kernel on the correct GPU (it finds out by itself), then I’m puzzled, and I’ll work on the minimal example on Monday.

Not exactly a minimal example, but I think it is sufficient to reproduce the issue:

I cloned https://github.com/pytorch/extension-cpp (the repository from the tutorial I linked in the question).

I opened check.py and changed line 86 from

    device = torch.device("cuda")

to

    device = torch.device("cuda:2")

I verified that the cuda flag is not used anywhere else to set the device of a tensor.

When I ran “python check.py --cuda forward” on the cluster with 4 GPUs, I saw similar behaviour: Python was using GPUs 0 and 2, and the check did not pass the CUDA test. Running the original version used only GPU 0 and passed both tests.

(On a side note, “python check.py --cuda backward” didn’t pass even without the code change.)

Thanks for the pointer! I can reproduce the issue and will debug it.

I was wrong: a device guard is needed for custom CUDA extensions.
Here is the diff that makes the example run on a non-default device:

diff --git a/cuda/lltm_cuda.cpp b/cuda/lltm_cuda.cpp
index 2434776..62c9628 100644
--- a/cuda/lltm_cuda.cpp
+++ b/cuda/lltm_cuda.cpp
@@ -1,5 +1,5 @@
 #include <torch/extension.h>
-
+#include <c10/cuda/CUDAGuard.h>
 #include <vector>

 // CUDA forward declarations
@@ -40,7 +40,7 @@ std::vector<torch::Tensor> lltm_forward(
   CHECK_INPUT(bias);
   CHECK_INPUT(old_h);
   CHECK_INPUT(old_cell);
-
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
 }

@@ -62,6 +62,7 @@ std::vector<torch::Tensor> lltm_backward(
   CHECK_INPUT(X);
   CHECK_INPUT(gate_weights);
   CHECK_INPUT(weights);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(weights));

   return lltm_cuda_backward(
       grad_h,