I am working on a problem where multiple workers send CUDA tensors to a shared queue that is read by the main process. While the code works as expected with CPU tensors (i.e. the tensors sent by the workers are retrieved correctly by the main process), I am finding that when the workers send CUDA tensors through the shared queue, the values read by the main process are often garbage.
Related discussion: Invalid device pointer using multiprocessing with CUDA.
For example, the following minimal code reproduces the issue: it works fine for CPU tensors, but with CUDA tensors the main process repeatedly reads only the last few tensors that were sent.
import torch
import torch.multiprocessing as mp
DEVICE = "cuda"
N = 10
done = mp.Event()
def proc(queue):
    t = torch.tensor(0., device=DEVICE)
    n = 0
    while n < N:
        t = t + 1
        print("sent:", t)
        queue.put(t)
        n += 1
    done.wait()

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    p = ctx.Process(target=proc, args=(queue,))
    p.daemon = True
    p.start()
    for _ in range(N):
        print("recv:", queue.get())
    done.set()
which prints out:
$ python examples/ex2.py
sent: tensor(1., device='cuda:0')
sent: tensor(2., device='cuda:0')
sent: tensor(3., device='cuda:0')
sent: tensor(4., device='cuda:0')
sent: tensor(5., device='cuda:0')
sent: tensor(6., device='cuda:0')
sent: tensor(7., device='cuda:0')
sent: tensor(8., device='cuda:0')
sent: tensor(9., device='cuda:0')
sent: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')
recv: tensor(9., device='cuda:0')
recv: tensor(10., device='cuda:0')
Based on the discussion in Invalid device pointer using multiprocessing with CUDA and @colesbury’s suggestion in Using torch.Tensor over multiprocessing.Queue + Process fails, I suspected that the issue is that the worker process needs to keep a reference to each CUDA tensor alive until it has been read by the main process. To test this, I appended the sent tensors to a temporary list in the worker, and then it worked fine!
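Concretely, the workaround I tested only changes the worker (keep_alive is just a throwaway name I used; everything else is as in the repro above):

import torch
import torch.multiprocessing as mp

DEVICE = "cuda"
N = 10
done = mp.Event()

def proc(queue):
    keep_alive = []  # keep a reference to every sent tensor until the worker exits
    t = torch.tensor(0., device=DEVICE)
    n = 0
    while n < N:
        t = t + 1
        print("sent:", t)
        queue.put(t)
        keep_alive.append(t)  # without this line, the worker may reuse t's memory before the main process reads it
        n += 1
    done.wait()

# the __main__ block is unchanged from the repro above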
This is, however, quite wasteful, because the workers produce a large number of tensors and it is not feasible to hold all of them in memory. I was wondering what the best practice is for such a use case. I could store the tensors in a limited-size buffer and hope that each one has had enough time to be read by the main process before its reference is dropped, but that seems too fragile. cc. @smth, @colesbury
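For concreteness, the limited-size buffer I have in mind would be something like this inside the worker (the maxlen of 32 is an arbitrary guess, which is exactly why it feels fragile):

from collections import deque

def proc(queue):
    keep_alive = deque(maxlen=32)  # drops the oldest reference once 32 tensors are buffered
    t = torch.tensor(0., device=DEVICE)
    n = 0
    while n < N:
        t = t + 1
        queue.put(t)
        keep_alive.append(t)  # hope the main process reads t before it falls off the deque
        n += 1
    done.wait()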