CUDA illegal memory access for a CUDA extension in multiprocessing

I wrote a simple CUDA extension that multiplies a tensor in place on the GPU. It works fine in a single process, but in multiprocessing mode it fails with an illegal memory access error. I’ve also tried compute-sanitizer, but it didn’t give any more meaningful information. Here’s a minimal reproducible code package; simply run python install && python

Here are the main snippets for a quick preview:

CUDA implementation:

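__global__ void raw_mul_cuda_forward_kernel(float* x) {
  // One block per (row, col) pair, one thread per element of the last dim.
  const int r = blockIdx.x;
  const int c = blockIdx.y;
  const int d = threadIdx.x;
  const int W = gridDim.y;   // number of columns
  const int D = blockDim.x;  // size of the last dimension
  // Row-major offset into a contiguous (H, W, D) tensor.
  x[(r * W + c) * D + d] *= 2.f;
}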

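void mul_cuda_forward(torch::Tensor x) {
  // x is expected to be a contiguous float32 CUDA tensor of shape (H, W, D).
  const auto H = x.size(0);
  const auto W = x.size(1);
  const auto D = x.size(2);
  const int threads = D;  // D must not exceed 1024 threads per block
  const dim3 blocks(H, W);
  raw_mul_cuda_forward_kernel<<<blocks, threads>>>(x.data_ptr<float>());
}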
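
The kernel maps block (r, c) and thread d to one element. As a pure-Python sanity check of that mapping (the helper names below are hypothetical, not part of the extension), the offset r * D * D + c * D + d only matches the row-major offset (r * W + c) * D + d of a contiguous (H, W, D) tensor when W == D. That holds for the (32, 32, 32) test here, so it is not what triggers the illegal access, but a non-cubic shape would skip elements and write past the end:

```python
def flat_index(r, c, d, W, D):
    """Row-major offset of element (r, c, d) in a contiguous (H, W, D) tensor."""
    return (r * W + c) * D + d


def original_index(r, c, d, D):
    """Offset r * D * D + c * D + d, which implicitly assumes W == D."""
    return r * D * D + c * D + d


def covered_offsets(H, W, D, index_fn):
    """Sorted offsets written when block (r, c) / thread d each touch one element."""
    return sorted(index_fn(r, c, d) for r in range(H) for c in range(W) for d in range(D))


H, W, D = 2, 3, 4  # hypothetical non-cubic shape
general = covered_offsets(H, W, D, lambda r, c, d: flat_index(r, c, d, W, D))
posted = covered_offsets(H, W, D, lambda r, c, d: original_index(r, c, d, D))
print(general == list(range(H * W * D)))  # True: every element written exactly once
print(max(posted), H * W * D - 1)          # 27 23: runs past the last element
```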

In Python:

import torch
import inplacemul_cuda
from tqdm import tqdm

def error_callback(error):
    print(error, flush=True)

def process(rank):
    device = torch.device(f"cuda:{rank}")
    x = torch.rand(32, 32, 32).to(device)
    gold = x * 2
    assert torch.abs(gold - x).sum() == 0

if __name__ == '__main__':
    assert torch.cuda.is_available()

    # sequential test
    for _ in tqdm(range(10)):
        process(0)
    print('Single-process test passed')

    # multiproc test
    import torch.multiprocessing as mp
    ctx = mp.get_context("spawn")
    N = 1000
    pool = ctx.Pool(processes=2)
    pbar = tqdm(total=N)
    rets = []
    num_gpus = torch.cuda.device_count()
    gpu = 0
    for _ in range(N):
        ret = pool.apply_async(process, args=(gpu,),
            callback=lambda _: pbar.update(1),
            error_callback=error_callback)
        rets.append(ret)
        gpu = (gpu + 1) % num_gpus
    for ret in rets:
        ret.wait()
    pool.close()
    pool.join()

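You might need to add a device guard. A kernel launched with <<<blocks, threads>>> runs on the current CUDA device, and in a freshly spawned worker that is cuda:0 even when x was allocated on cuda:1, which is why the same code works in a single process on cuda:0. The guard makes the given device current for the rest of the scope and restores the previous one when it goes out of scope:

at::cuda::OptionalCUDAGuard device_guard(id);

where id is the device of the tensor being written (x here).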

Yes, that works, thanks for the quick reply. For reference, I added #include <c10/cuda/CUDAGuard.h> at the top and at::cuda::OptionalCUDAGuard device_guard(x.device()); right before the kernel launch.
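
To make the failure mode concrete, here is a toy pure-Python model of it. It is not CUDA code, and current_device, optional_device_guard and launch_kernel are hypothetical stand-ins. It simplifies the real behavior to three rules: a launch runs on the process's current device, a spawned worker starts on device 0, and touching memory that lives on another device is treated as the illegal access. The guard switches the current device on entry and restores the previous one on exit, mirroring what OptionalCUDAGuard does with its C++ scope:

```python
from contextlib import contextmanager

# Hypothetical stand-in for the CUDA runtime's "current device";
# a freshly spawned worker process starts on device 0.
current_device = 0


@contextmanager
def optional_device_guard(device_index):
    """Toy analogue of at::cuda::OptionalCUDAGuard: switch, then restore on exit."""
    global current_device
    previous = current_device
    current_device = device_index
    try:
        yield
    finally:
        current_device = previous  # restored even if the body raises


def launch_kernel(tensor_device):
    """Toy launch: runs on current_device; data on another device is modeled
    as the illegal memory access from the post."""
    if tensor_device != current_device:
        raise RuntimeError(
            f"illegal memory access: kernel on cuda:{current_device}, "
            f"data on cuda:{tensor_device}"
        )
    return f"ok on cuda:{current_device}"


def forward_without_guard(tensor_device):
    return launch_kernel(tensor_device)


def forward_with_guard(tensor_device):
    with optional_device_guard(tensor_device):
        return launch_kernel(tensor_device)


print(forward_without_guard(0))  # ok on cuda:0
print(forward_with_guard(1))     # ok on cuda:1
print(current_device)            # 0: the guard restored the previous device
```

Calling forward_without_guard(1) raises the toy RuntimeError, which corresponds to the worker that received a cuda:1 tensor in the multiprocessing test.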