Using torch.Tensor over multiprocessing.Queue + Process fails

Hi,

Context

I have a simple algorithm that distributes a number of tasks across a list of Process objects, and the workers send their results back through a Queue. I was previously using numpy for this kind of job.

Problem

To be more consistent with the rest of my code, I decided to use only torch tensors. Unfortunately, it seems that transferring a torch.Tensor over a Queue is not possible, maybe because of pickling. I get the following error when calling get() to retrieve the result from my Queue:

    worker_result = done_queue.get()
  File "/home/ganaye/deps/miniconda3/lib/python3.5/multiprocessing/queues.py", line 113, in get
    return ForkingPickler.loads(res)
  File "/home/ganaye/deps/miniconda3/lib/python3.5/site-packages/torch/multiprocessing/reductions.py", line 70, in rebuild_storage_fd
    fd = df.detach()
  File "/home/ganaye/deps/miniconda3/lib/python3.5/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/ganaye/deps/miniconda3/lib/python3.5/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/home/ganaye/deps/miniconda3/lib/python3.5/multiprocessing/connection.py", line 493, in Client
    answer_challenge(c, authkey)
  File "/home/ganaye/deps/miniconda3/lib/python3.5/multiprocessing/connection.py", line 732, in answer_challenge
    message = connection.recv_bytes(256)         # reject large message
  File "/home/ganaye/deps/miniconda3/lib/python3.5/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/ganaye/deps/miniconda3/lib/python3.5/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/ganaye/deps/miniconda3/lib/python3.5/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer

Conclusion

After quickly switching back to numpy and sending random numpy arrays, I am convinced the problem comes from using torch tensors. The error can be reproduced with this tested code:

Code

import multiprocessing as mp
import torch

def extractor_worker(done_queue):
    done_queue.put(torch.Tensor(10, 10))  # the worker exits right after this put

producers = []
done_queue = mp.Queue()
for i in range(0, 1):
    process = mp.Process(target=extractor_worker,
                         args=(done_queue,))
    process.start()
    producers.append(process)

result_arrays = []
nb_ended_workers = 0
while nb_ended_workers != 1:
    worker_result = done_queue.get()  # ConnectionResetError here: the producer has already exited
    if worker_result is None:
        nb_ended_workers += 1
    else:
        result_arrays.append(worker_result)

Surprisingly, it seems that writing works fine, but reading the object back throws the error.
If any of you has a workaround for sending torch tensors over a Queue, please share it! I can switch back to numpy if this is really impossible.

Thanks !

Your background process needs to be alive when the main process reads the tensor.

Here’s a small modification to your example:

import multiprocessing as mp
import torch

done = mp.Event()

def extractor_worker(done_queue):
    done_queue.put(torch.Tensor(10, 10))
    done_queue.put(None)  # sentinel: this worker is done producing
    done.wait()           # keep the worker alive until the main process has read the tensor

producers = []
done_queue = mp.Queue()
for i in range(0, 1):
    process = mp.Process(target=extractor_worker,
                         args=(done_queue,))
    process.start()
    producers.append(process)

result_arrays = []
nb_ended_workers = 0
while nb_ended_workers != 1:
    worker_result = done_queue.get()
    if worker_result is None:
        nb_ended_workers += 1
    else:
        result_arrays.append(worker_result)
done.set()  # all results have been read; let the workers exit

This is an unfortunate result of how Python pickling handles sending file descriptors. (We send tensors via shared memory instead of writing the values to the queue). The steps are roughly:

  1. The background process sends a token over the mp.Queue.
  2. When the main process reads the token, it opens a Unix socket back to the background process.
  3. The background process sends the file descriptor over that Unix socket.

If the background process has already exited by the time steps 2 and 3 happen, the connection fails with the ConnectionResetError shown above.
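
If keeping the producer alive is inconvenient, one possible workaround (my own suggestion, not something stated in this thread; it pays for an extra copy) is to send the data by value as a numpy array and rebuild the tensor on the receiving side, roughly like this:

import multiprocessing as mp
import numpy as np
import torch

def extractor_worker(done_queue):
    t = torch.Tensor(10, 10)
    done_queue.put(t.numpy())  # numpy arrays are pickled by value, so no file descriptor is shared
    done_queue.put(None)

if __name__ == "__main__":
    done_queue = mp.Queue()
    process = mp.Process(target=extractor_worker, args=(done_queue,))
    process.start()

    results = []
    while True:
        item = done_queue.get()
        if item is None:
            break
        results.append(torch.from_numpy(item))  # rebuild the tensor in the main process
    process.join()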

Thank you @colesbury, this is an interesting implementation detail.

From what you just said, I guess the previous (numpy) solution worked because the array is sent by value rather than through a file descriptor? Otherwise I don't get it, since I have been using the numpy solution for months now.

Can't wait to try this.

Does sending tensors run faster than sending (and having to convert to/from) numpy arrays?

@ethancaballero my first thought was that it would be nearly the same cost, because the two can share the same storage, so I kept numpy as the main solution for storing an image array.
However, we have to keep in mind that the final object will be a torch tensor, so I compared the two approaches: building a batch of 512 took on average 3.3 seconds with numpy and less than 0.005 seconds with torch. This was measured after a few iterations, to make sure my 6 workers were running efficiently.

This speedup is the reason I decided to switch to torch tensors only. As I said, in the end you only need a torch tensor, so there is no value in using numpy purely as storage. I think the numpy -> torch conversions are the cause of the slowdown.

On the other hand, I was pretty satisfied with the memory footprint of numpy; I hope it will stay as low with torch.
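
For what it's worth, a rough way to check the conversion-overhead hypothesis (shapes, batch size, and the timing loop are assumptions, not the original measurement) is to time batch construction with and without the numpy -> torch conversion:

import time
import numpy as np
import torch

def batch_via_numpy(n=512, shape=(3, 64, 64)):
    # build the batch as numpy arrays, convert to a torch tensor at the end
    arrays = [np.random.rand(*shape).astype(np.float32) for _ in range(n)]
    return torch.from_numpy(np.stack(arrays))

def batch_via_torch(n=512, shape=(3, 64, 64)):
    # build the batch directly from torch tensors
    tensors = [torch.rand(*shape) for _ in range(n)]
    return torch.stack(tensors)

for fn in (batch_via_numpy, batch_via_torch):
    fn()  # warm-up
    start = time.time()
    fn()
    print(fn.__name__, time.time() - start)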

@trypag I think a simpler solution is to replace mp.Queue() with mp.Manager().Queue().
I haven't yet timed it against the done = mp.Event() + done.wait() + done.set() approach to see if the speed changes.
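
Roughly, that swap applied to the original example might look like the sketch below (the sentinel handling is kept from the earlier code; this is an illustration, not a measured comparison):

import multiprocessing as mp
import torch

def extractor_worker(done_queue):
    done_queue.put(torch.Tensor(10, 10))
    done_queue.put(None)  # sentinel: nothing more to produce

if __name__ == "__main__":
    manager = mp.Manager()
    done_queue = manager.Queue()  # the queue is hosted by a separate manager process

    process = mp.Process(target=extractor_worker, args=(done_queue,))
    process.start()

    result_arrays = []
    while True:
        worker_result = done_queue.get()
        if worker_result is None:
            break
        result_arrays.append(worker_result)
    process.join()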

For my application the number of transfers is quite limited, so I will stay with the solution proposed. However, for more demanding workloads like distributed computing, it might be useful to explore other solutions.

I had a related question that I posted here - CUDA tensors on multiprocessing queue.

In our application, we have a bunch of workers putting CUDA tensors onto a shared queue that is read by the main process. It seems that the workers need to keep the CUDA tensors in memory until the main process has read them. There are two issues here. First, the main process actually runs without throwing any errors, but the tensors it reads often contain garbage values; it would be great if this triggered an exception instead. Second, this constraint seems to force some communication in the other direction as well, from the main process back to the workers: for example, each worker could keep its tensors in a temporary buffer that is cleared periodically based on acknowledgements from the main process (see the sketch below). Is this the best way to handle such a use case?
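
A hedged sketch of that acknowledgement idea (the queue names, item count, and release policy are assumptions, not an established pattern from this thread; it assumes a CUDA device is available): each worker keeps a reference to every CUDA tensor it has sent until the main process confirms it has copied it.

import torch
import torch.multiprocessing as mp

def worker(result_queue, ack_queue, device):
    pending = {}  # keep tensors alive until the main process acknowledges them
    for i in range(10):
        t = torch.randn(4, 4, device=device)
        pending[i] = t  # hold a reference so the shared CUDA memory stays valid
        result_queue.put((i, t))
        while not ack_queue.empty():  # drain acknowledgements received so far
            pending.pop(ack_queue.get(), None)
    while pending:  # wait for the remaining acknowledgements before exiting
        pending.pop(ack_queue.get(), None)

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # CUDA tensors require the spawn start method
    result_queue, ack_queue = ctx.Queue(), ctx.Queue()
    p = ctx.Process(target=worker, args=(result_queue, ack_queue, "cuda:0"))
    p.start()
    for _ in range(10):
        idx, tensor = result_queue.get()
        local_copy = tensor.clone()  # copy before telling the worker to drop its reference
        ack_queue.put(idx)
    p.join()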

FWIW, another solution is to use mp.JoinableQueue, calling queue.task_done() on the consumer side and queue.join() on the producer side.

Hi @eacousineau, I'm trying to use mp.JoinableQueue with one producer process that keeps putting items into the queue and another consumer process that keeps getting items from it. Do queue.join() and queue.task_done() need to be called only once per process, or once for every put() and get() call? It would be great if you could give a simple example. Thank you.

Hi @Yi_Zhang, it will be the latter:

For each get() used to fetch a task, a subsequent call to task_done() tells the queue that the processing on the task is complete.

This is the toy example I had used to check it for myself:

Feel free to modify it to test for multi-input multi-output queue processing.
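
For reference, a minimal sketch of that pattern (the process layout, sentinel, and item count here are assumptions, not the original toy example): the consumer calls task_done() once per get(), and the producer blocks on queue.join() until every item has been processed.

import multiprocessing as mp

def consumer(queue):
    while True:
        item = queue.get()
        if item is None:  # sentinel: stop consuming
            queue.task_done()
            break
        print("consumed", item)
        queue.task_done()  # one task_done() per get()

if __name__ == "__main__":
    queue = mp.JoinableQueue()
    proc = mp.Process(target=consumer, args=(queue,))
    proc.start()

    for i in range(5):  # producer side: one put() per task
        queue.put(i)
    queue.put(None)

    queue.join()  # returns once every put() item has been marked task_done()
    proc.join()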

Using Manager().Queue() solved the problem for me:

import multiprocessing as mp
from queue import Queue  # only for the type annotation; manager.Queue() returns a proxy with the same interface

ctx = mp.get_context("spawn")
manager = ctx.Manager()

data_feed: Queue = manager.Queue(10000)