Deadlock/Race Condition when using Multiprocessing Queues with Tensors

Hi guys,

Context
I have a training setup that uses multiple processes to preprocess data and prepare batches, and then trains a model on those batches in multiple processes.

To be more precise I have three tasks:

  1. There is one process that produces data and puts it in a queue.
  2. There are multiple processes that read that data, transform it and then construct batches. A batch consists of two Tensors. To improve performance, and following the documentation, I try to reuse Tensors that are already in shared memory.
  3. There are then multiple processes that take the prepared batches and train a model with them. After each batch, they pass the Tensors back to the processes that prepare the batches.

Problem

Sometimes, but not always, the processes of the last task (task 3) are unable to take anything out of the queue. I have checked that all processes are still alive, since a dead process is mentioned as a possible cause in other discussions.
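
For anyone who wants to check the "blocked but still alive" observation, here is a minimal sketch of a diagnostic read. The helper name `get_or_report` and the 5 second default timeout are hypothetical and not part of my real code. It assumes a queue whose `get` accepts a `timeout` argument (`mp.Queue` does, `mp.SimpleQueue` does not) and workers that expose `name` and `is_alive()`, as `mp.Process` objects do.

```python
import queue


def get_or_report(q, workers, timeout=5.0):
    """Read one item from q, or report worker liveness if the read times out.

    Hypothetical helper for debugging only.
    Returns (True, item) on success, or (False, {worker name: is_alive()})
    if nothing arrived within `timeout` seconds.
    """
    try:
        # Queue.get raises queue.Empty when the timeout expires.
        return True, q.get(timeout=timeout)
    except queue.Empty:
        return False, {w.name: w.is_alive() for w in workers}
```

Replacing a bare `batch_queue.get()` with this call would turn a silent hang into a report that shows whether the workers were still running at the moment the read timed out.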

It works most of the time when using only 1 process for each task. But when using more than 2 processes for tasks 2 and 3, it almost always blocks. According to the documentation, one should use SimpleQueue if these problems occur. I tried that, but it didn't work either. In my real code the processes using SimpleQueue wouldn't even start, and in the minimal example below it works a bit more reliably than before, but fails to restart after being interrupted or after normal termination.

Code
I made a minimal example out of my code that reproduces the error. You may need to run it multiple times with different numbers of workers.

import time
import torch
import torch.multiprocessing as mp


done = mp.Event()

def first_process(num_data, data_queue, num_workers):
    """ Data producer """
    for i in range(num_data):
        data_queue.put(i)
        time.sleep(0.5)

    for i in range(num_workers):
        data_queue.put(None)
    done.wait()

def second_process(data_queue, batch_queue, return_batch_queue):
    """ Transforms data from first process into batches.
        Alters Tensors that are already in shared memory (return_batch_queue).
    """
    while True:
        i = data_queue.get()

        if i is None:
            batch_queue.put(None)
            break

        item1, item2 = return_batch_queue.get()  # get the tensors from shared memory

        # change them with some dummy task
        item1.zero_()
        item2.zero_()
        item1 += 1 * i
        item2 += 2 * i

        batch_queue.put((item1, item2))  # pass the prepared batch on to the third processes

    done.wait()

def third_process(batch_queue, return_batch_queue, progress_queue):
    """ Here we would train the model with the batches prepared
        beforehand.
    """
    while True:
        batch = batch_queue.get()  # get the tensors from shared memory

        if batch is None:
            progress_queue.put(None)
            break

        b1, b2 = batch
        print(b1)  # Don't do anything special, just print the batch

        return_batch_queue.put(batch)  # send the tensors back

    done.wait()

def main(num_other_processes=1, num_data=50):
    # initialize the queues
    data_queue = mp.Queue(maxsize=2 * num_other_processes)
    batch_queue = mp.Queue()
    return_batch_queue = mp.Queue()
    progress_queue = mp.Queue()

    # Try SimpleQueue?
    # batch_queue = mp.SimpleQueue()
    # return_batch_queue = mp.SimpleQueue()

    # Fill return_batch_queue with some tensors that are then moved around.
    for _ in range(5 * num_other_processes):
        batch = torch.Tensor(5, 5), torch.Tensor(5, 5)
        return_batch_queue.put(batch)

    # Initialize Processes
    workers = [mp.Process(target=first_process,
                           args=(num_data, data_queue, num_other_processes))]

    workers += [mp.Process(target=second_process,
                           args=(data_queue, batch_queue, return_batch_queue))
                for _ in range(num_other_processes)]

    workers += [mp.Process(target=third_process,
                           args=(batch_queue, return_batch_queue, progress_queue))
                for _ in range(num_other_processes)]

    for process in workers:
        process.daemon = True
        process.start()

    # Wait till all third processes are done
    num_ended_workers = 0
    while num_ended_workers < num_other_processes:
        progress = progress_queue.get()

        if progress is None:
            print("Seems like a worker has ended.")
            num_ended_workers += 1

    done.set()  # let all child processes return

    for process in workers:
        process.join()

    print("Finished training.")

if __name__ == '__main__':
    main(4)

Any help is greatly appreciated!
Thanks!