CUDA extension with multiple GPUs

Hey everyone!

I have a problem with my custom CUDA kernel on a machine with multiple GPUs.
I wrote a kernel that essentially does torch.logical_not(torch.isclose(X, scalar, rtol=radius)). The kernel code looks like this:

(file custom_ops_kernel.cu)

#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>          // torch::Tensor, AT_DISPATCH_FLOATING_TYPES
#include <c10/cuda/CUDAFunctions.h>   // c10::cuda::current_device / set_device

// C++ PROJECT INCLUDES


template<typename scalar_t>
__global__
void
cuda_notnear_center_kernel(const scalar_t* __restrict__ x,
                           bool* __restrict__ z,
                           const double lthreshold,
                           const double rthreshold,
                           size_t N)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;

    if (index < N)
    {
        z[index] = x[index] < lthreshold || x[index] > rthreshold;
    }
}



torch::Tensor
cuda_notnear_center(torch::Tensor x,
                    torch::Tensor out,
                    double center,
                    double radius)
{
    const double lthreshold = center - radius;
    const double rthreshold = center + radius;

    // auto z = torch::empty_like(x, torch::kBool);
    const int num_threads = 1024;
    const int num_blocks = (x.size(0) / num_threads) + 1;
    // printf("num_blocks: %i\n", num_blocks);

    const auto current_device = c10::cuda::current_device();
    c10::cuda::set_device(x.get_device());
    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "cuda_notnear_center", ([&] {
        cuda_notnear_center_kernel<scalar_t><<<num_blocks, num_threads>>>(
            x.data_ptr<scalar_t>(),
            out.data_ptr<bool>(),
            lthreshold,
            rthreshold,
            x.size(0));
    }));
    c10::cuda::set_device(current_device);

    // printf("Error: %s", cudaGetLastError());

    return out;
}

The kernel is called from a C++ file, custom_ops.cc, whose relevant parts look like this:

#include <torch/extension.h>
#include <vector>


// C++ PROJECT INCLUDES

// taken from https://pytorch.org/tutorials/advanced/cpp_extension.html
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_SAME_SIZE(x, y) TORCH_CHECK(torch::is_same_size(x, y), #x " must have same size as " #y)
#define CHECK_TYPE(x, t) TORCH_CHECK(x.scalar_type() == t, #x " must have same type as " #t)

torch::Tensor
cuda_notnear_center(torch::Tensor x,
                    torch::Tensor out,
                    double center,
                    double radius);

torch::Tensor&
notnear_center_out(torch::Tensor& out,
                   const torch::Tensor& x,
                   const double center,
                   const double radius)
{
    CHECK_INPUT(x);
    CHECK_INPUT(out);
    CHECK_SAME_SIZE(x, out);
    CHECK_TYPE(out, torch::kBool);

    cuda_notnear_center(x.view({-1}), out.view({-1}), center, radius);
    return out;
}


torch::Tensor
notnear_center(const torch::Tensor& x,
               const double center,
               const double radius)
{
    CHECK_INPUT(x);
    torch::Tensor out = torch::empty_like(x, torch::kBool);
    cuda_notnear_center(x.view({-1}), out.view({-1}), center, radius);
    return out;
}

My problem is that my machine has multiple GPUs. When I launch this kernel on a GPU other than GPU 0, the kernel also ends up allocating memory on GPU 0. I have done some testing and narrowed the issue down to this kernel. I thought I could set the device with c10::cuda::set_device(), but is this the proper way to do it?

Here is the output of nvidia-smi on my 4-GPU machine. The main processes are running on GPU 1 and GPU 3, but both are also allocating memory on GPU 0:

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    128024      C   python3                          1481MiB |
|    0   N/A  N/A    129010      C   python3                          1481MiB |
|    1   N/A  N/A    129010      C   python3                         10157MiB |
|    3   N/A  N/A    128024      C   python3                         10195MiB |
+-----------------------------------------------------------------------------+

Are you looking for a torch::CUDAGuard RAII variable?
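
That is, roughly the following (just a sketch of the RAII behaviour; the guard lives in <c10/cuda/CUDAGuard.h>, and do_something_on is a made-up name only meant to show the scoping):

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical helper, only to illustrate how the guard scopes the device.
void do_something_on(const torch::Tensor& x)
{
    // Constructing the guard saves the current device and switches to the
    // device that x lives on.
    c10::cuda::CUDAGuard guard(x.get_device());

    // ... kernel launches / tensor allocations here run with x's device
    //     as the current device ...

}   // guard destroyed here: the previous device is restored automatically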

Best regards

Thomas

I’m not sure. This is my first attempt at running a kernel (that I wrote) on a GPU that isn’t GPU 0, so I don’t know much at all about specifying the device.

For instance, if I were to use a CUDAGuard, does it go in the kernel itself, or is my current placement correct?

Take these with a grain of salt, as they come from someone who does single-GPU work more often than multi-GPU, but:

  • I’d use the guard instead of set_device (see the sketch after this list),
  • I’d set the device based on the input tensors, not the local state of PyTorch,
  • I’d do it right at the top of the function taking the tensor (it also affects new tensors that you might create).
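
Concretely, something like this in custom_ops.cc (only a sketch; it reuses the function and macro names from your post, and the guard header is <c10/cuda/CUDAGuard.h>):

#include <c10/cuda/CUDAGuard.h>

torch::Tensor&
notnear_center_out(torch::Tensor& out,
                   const torch::Tensor& x,
                   const double center,
                   const double radius)
{
    // Guard first: the checks, the views, the kernel launch, and any new
    // tensors created below all see x's device as the current device; the
    // previous device is restored automatically when the function returns.
    c10::cuda::CUDAGuard guard(x.get_device());

    CHECK_INPUT(x);
    CHECK_INPUT(out);
    CHECK_SAME_SIZE(x, out);
    CHECK_TYPE(out, torch::kBool);

    cuda_notnear_center(x.view({-1}), out.view({-1}), center, radius);
    return out;
}

With the guard at this level, the set_device/restore pair inside cuda_notnear_center would no longer be needed.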

Best regards

Thomas

Thanks a ton, I really appreciate it :slight_smile:

Update: I tried using c10::cuda::CUDAGuard guard(x.get_device()), but I still had the same problem. I did, however, move c10::cuda::set_device(x.get_device()) to the top of the function, which worked for some reason.

To be clearer, I now have the following lines at the top of my function:

torch::Tensor
cuda_notnear_center(torch::Tensor x,
                    torch::Tensor out,
                    double center,
                    double radius)
{
    const auto current_device = c10::cuda::current_device(); // remember the current device so it can be restored after the call
    c10::cuda::set_device(x.get_device()); // switch to the device of tensor x

    ... // function body

    c10::cuda::set_device(current_device); // restore the previous device after the call
    return out;
}

This seemed to work, but I have no idea why:

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    1   N/A  N/A    104618      C   python3                         10155MiB |
|    3   N/A  N/A    103802      C   python3                         10195MiB |
+-----------------------------------------------------------------------------+


Glad you figured it out!
