CUDA tensors on multiprocessing queue

I am working on a problem where multiple workers send CUDA tensors to a shared queue that is read by the main process. While the code works great with CPU tensors (i.e. the tensors sent by the workers are retrieved correctly by the main process), I am finding that when the workers send CUDA tensors through the shared queue, the tensor values read by the main process are often garbage values.

Related discussion: Invalid device pointer using multiprocessing with CUDA.

For example, the following minimal code reproduces the issue. It works fine with CPU tensors, but with CUDA tensors the main process repeatedly reads only the last few tensors that were sent.

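import torch
import torch.multiprocessing as mp


DEVICE = "cuda"
N = 10


def proc(queue, done):
    t = torch.tensor(0., device=DEVICE)
    n = 0
    while n < N:
        # `t + 1` allocates a new tensor; rebinding `t` releases the producer's
        # reference to the tensor sent in the previous iteration.
        t = t + 1
        print("sent:", t)
        queue.put(t)
        n += 1
    # Keep the producer alive until the main process has read everything.
    done.wait()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    # Create the Event from the same spawn context and pass it to the worker:
    # a module-level Event is re-created (and never set) in the spawned child.
    done = ctx.Event()
    p = ctx.Process(target=proc, args=(queue, done))
    p.daemon = True
    p.start()
    for _ in range(N):
        print("recv:", queue.get())
    done.set()
    p.join()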

which prints out:

$ python examples/ex2.py
sent: tensor(1., device='cuda:0')
sent: tensor(2., device='cuda:0')
sent: tensor(3., device='cuda:0')
sent: tensor(4., device='cuda:0')
sent: tensor(5., device='cuda:0')
sent: tensor(6., device='cuda:0')
sent: tensor(7., device='cuda:0')
sent: tensor(8., device='cuda:0')
sent: tensor(9., device='cuda:0')
sent: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')

Based on the discussion in Invalid device pointer using multiprocessing with CUDA and @colesbury’s suggestion in Using torch.Tensor over multiprocessing.Queue + Process fails, I suspected that we need to hold a reference to each CUDA tensor in the worker process until the main process has read it. To test this, I appended the tensors to a temporary list in the worker, and then it worked fine!

This is quite wasteful, however, because the workers produce a large number of tensors and it is not feasible to hold them all in memory. What is the best practice for this use case? I could store the tensors in a limited-size buffer and hope that the oldest one has had enough time to be read by the main process, but that seems too fragile. cc @smth, @colesbury
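One way to avoid holding every tensor, sketched here as an assumption rather than something proposed in the thread: bound the number of tensors the worker keeps alive, and have the main process acknowledge each tensor after it has made its own copy. The worker drops its reference only when the acknowledgement arrives, so it never frees memory the consumer may still be reading. `producer`, `consume`, the `(seq, tensor)` message format, the separate ack queue and `MAX_IN_FLIGHT` are all hypothetical names. Because `clone()` on a CUDA tensor is asynchronous, the consumer calls `torch.cuda.synchronize()` before acknowledging.

```python
import torch
import torch.multiprocessing as mp

MAX_IN_FLIGHT = 4  # hypothetical bound on tensors the producer keeps alive


def producer(data_q, ack_q, n, device):
    held = {}  # seq -> tensor; keeps each allocation alive until acknowledged
    for seq in range(n):
        # Block while too many tensors are still unacknowledged.
        while len(held) >= MAX_IN_FLIGHT:
            held.pop(ack_q.get())
        t = torch.full((1,), float(seq + 1), device=device)
        held[seq] = t
        data_q.put((seq, t))
    # Collect the remaining acks before returning, so nothing is freed early.
    while held:
        held.pop(ack_q.get())


def consume(data_q, ack_q, n):
    out = []
    for _ in range(n):
        seq, shared = data_q.get()
        private = shared.clone()  # copy owned by this process
        if private.is_cuda:
            # clone() is asynchronous on CUDA; wait for it before releasing the source.
            torch.cuda.synchronize()
        del shared  # drop the handle to the producer's memory
        ack_q.put(seq)  # the producer may now free its original
        out.append(private)
    return out


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ctx = mp.get_context("spawn")
    data_q, ack_q = ctx.Queue(), ctx.Queue()
    p = ctx.Process(target=producer, args=(data_q, ack_q, 10, device))
    p.start()
    print([t.item() for t in consume(data_q, ack_q, 10)])
    p.join()
```

Compared with a fixed-size buffer that hopes the consumer keeps up, the acknowledgement makes the release point explicit. The cost is one extra device-side copy per tensor in the consumer, and a producer that blocks whenever the consumer falls `MAX_IN_FLIGHT` tensors behind.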

Unfortunately, I don’t know of a good way to do this. Trying to manage the lifetimes of CUDA tensors across processes is complicated. I try to only share CUDA tensors that have the same lifetime as the program (like model weights for example).
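A minimal sketch of that pattern, assuming the shared tensor is created once in the main process and outlives every worker: the parent allocates the weights, passes them to spawned workers as a `Process` argument, and joins the workers before the weights go out of scope. The torch.multiprocessing documentation states the related rule that the sending process must keep the original CUDA tensor alive as long as the receiving process retains a copy. `worker`, `main` and the toy weights are illustrative only.

```python
import torch
import torch.multiprocessing as mp


def worker(weights, results):
    # Read-only use of a tensor owned by the parent process. For CUDA tensors
    # this maps the parent's allocation rather than copying it.
    results.put(weights.sum().item())


def main(device, num_workers=2):
    ctx = mp.get_context("spawn")
    # Allocated once by the parent and kept alive until every worker has exited.
    weights = torch.ones(4, device=device)
    results = ctx.Queue()
    procs = [ctx.Process(target=worker, args=(weights, results))
             for _ in range(num_workers)]
    for p in procs:
        p.start()
    # The timeout avoids hanging forever if a worker crashes.
    sums = [results.get(timeout=60) for _ in procs]
    for p in procs:
        p.join()
    return sums


if __name__ == "__main__":
    print(main("cuda" if torch.cuda.is_available() else "cpu"))
```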

Thanks @colesbury. This would be nice to have for parallel MCMC sampling. For the time being, I’m thinking of using events to periodically clear tensors out of shared memory; let’s see how far that gets us.

I’d be interested to hear whether you made any progress here, @neerajprad @colesbury.

I’ve just tried running your minimal example with PyTorch 1.11.0 and was not able to reproduce the issue (the tensors are received in the correct order). I also tried adding a time.sleep call in the proc loop and a gc.collect() call in both the send and receive loops to trigger garbage collection of the shared CUDA tensors, but the problem still did not appear: the order remained correct.

This kind of data prefetching is appealing because it avoids the limitations of .to(non_blocking=True): namely, it requires pin_memory=True (which has its own set of potential issues), and you are not “allowed” to log values that originate from CUDA tensors to the terminal between forward-backward passes/batches without sacrificing the effect of non_blocking.
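For comparison, a minimal sketch of the `.to(non_blocking=True)` alternative: batches stay as CPU tensors, each one is pinned, and the host-to-device copy of the next batch is enqueued before the current batch is handed out. `prefetch_to_device` is a hypothetical helper, not a PyTorch API. Note that `pin_memory()` is itself a synchronous host-side copy, and that calling `.item()` on (or printing) a CUDA tensor makes the host wait for the GPU, which is the logging caveat described above.

```python
import torch


def prefetch_to_device(batches, device):
    """Yield batches on `device`, enqueueing the copy of the next batch early.

    Hypothetical helper: the host-to-device copy of batch i + 1 is issued
    before batch i is handed to the caller.
    """
    pin = torch.device(device).type == "cuda"

    def move(cpu_batch):
        if pin:
            # non_blocking H2D copies are only asynchronous from pinned memory.
            cpu_batch = cpu_batch.pin_memory()
        return cpu_batch.to(device, non_blocking=True)

    it = iter(batches)
    try:
        nxt = move(next(it))
    except StopIteration:
        return
    for cpu_batch in it:
        cur, nxt = nxt, move(cpu_batch)
        yield cur
    yield nxt


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data = [torch.full((2,), float(i)) for i in range(3)]
    for batch in prefetch_to_device(data, device):
        # sum() is enqueued asynchronously; .item() or print would synchronize.
        total = batch.sum()
    print(total.item())  # a single synchronization at the end
```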