Problems with torch.multiprocessing.spawn and SimpleQueue

I am working with pytorch-lightning in an effort to bring objects back to the master process when using DistributedDataParallel. Lightning launches its sub-processes with torch.multiprocessing.spawn().

I ran into some issues and decided to build a tiny test case to try things out. Unfortunately, I cannot seem to share a SimpleQueue when using torch.multiprocessing.spawn().

I am on Ubuntu 18.04, python 3.6.8, torch 1.4.

Here is the code:

import torch.multiprocessing as mp

def f(i, q):
    print(f"in f(): {q} {q.empty()}")
    print(f"{q.get()}")


if __name__ == '__main__':
    q = mp.SimpleQueue()
    q.put(['hello'])
    p = mp.spawn(f, (q,))
    print(f"main {q.empty()} {q.get()}")

This results in:

in f(): <multiprocessing.queues.SimpleQueue object at 0x7f97b7b71eb8> False
Traceback (most recent call last):
  File "test.py", line 36, in <module>
    p = mp.spawn(f, (q,))
  File "/home/seth/.local/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
    while not spawn_context.join():
  File "/home/seth/.local/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 107, in join
    (error_index, name)
Exception: process 0 terminated with signal SIGSEGV

The SimpleQueue object appears to be valid - a dir() call shows all the expected attributes - but the only method that seems to work is empty(). I have not tried everything, but the other basic calls all result in SIGSEGV. Ouch.

If I do not use mp.spawn() but instead use a normal Process start and join, it of course works fine. More curiously, it also works fine if I mimic what mp.spawn() does myself - without PyTorch's SpawnContext version of join():

import multiprocessing as mp

def f(q):
    print(f"in f(): {q} {q.empty()}")
    print(f"{q.get()}")

if __name__ == '__main__':
    ctx = mp.get_context('spawn')          # start the child via the 'spawn' context, like mp.spawn() does
    q = ctx.SimpleQueue()                  # the queue comes from that same context
    q.put(['hello'])
    p = ctx.Process(target=f, args=(q,))
    p.start()
    p.join()

Is this a bug - or am I doing something wrong?


I have solved this problem - but the solution seems odd to me, and I am concerned I am still doing something wrong.

If I create the queue using the context returned from mp.get_context('spawn'), and use the torch.multiprocessing module to make the actual spawn() call, everything works fine:

import torch.multiprocessing as mp

def f(i, q):
    print(f"in f(): {q} {q.empty()}")
    print(f"{q.get()}")

if __name__ == '__main__':
    smp = mp.get_context('spawn')   # same start method that mp.spawn() uses internally
    q   = smp.SimpleQueue()         # queue built from the 'spawn' context
    q.put(['hello'])
    p   = mp.spawn(f, (q,))

I tried this because torch.multiprocessing.spawn() itself starts its worker processes from the same context returned by get_context('spawn'). It seems reasonable that the queue's internals have to match the sharing mechanism, which differs depending on whether the children are forked or spawned. I still have to call spawn() from torch.multiprocessing, though, because the context object does not have a spawn() method.
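For reference, here is roughly what I believe mp.spawn() is doing under the hood - a sketch only, based on the SpawnContext/start_processes behavior visible in the tracebacks, not the real source; spawn_sketch is just an illustrative name:

import multiprocessing

def spawn_sketch(fn, args=(), nprocs=1):
    # Rough illustration of torch.multiprocessing.spawn(); not the real implementation.
    ctx = multiprocessing.get_context('spawn')     # workers are always started via the 'spawn' context
    procs = []
    for i in range(nprocs):
        p = ctx.Process(target=fn, args=(i,) + tuple(args))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()   # the real version wraps this in SpawnContext.join(), which surfaces child failures

So presumably the original failure comes from handing a queue created in the default context (fork, on Linux) to workers started from the spawn context.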

But it gives me pause that apparently nobody else has this problem. I have not seen anything about it in any online forums, and nothing in the docs. Is it really possible I am the only person to run into this?

Again, any commentary from more knowledgeable users would be welcome.

s


You are not alone! Thank you very much for your solution.
I encounter the same exception using a Queue with spawn(); since I handle CUDA tensors I have to use that method. The exception occurs specifically when I call a method on the Queue inside a process, like put(). Since the objects I put in the Queue contain CUDA tensors located on a GPU, I thought it might be linked to the memory location, the main process being on the CPU, but even if I move the tensor to the CPU it still occurs…
This causes me a great deal of pain, since I use multiprocessing to accelerate post-processing after inference, the code being very, very slow otherwise…
I am trying your solution.
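In case it helps, here is roughly how I am applying it - a minimal sketch, where worker() and the random tensor are placeholders for my real post-processing:

import torch
import torch.multiprocessing as mp

def worker(rank, q):
    # With the queue built from the 'spawn' context, get() no longer crashes here
    t = q.get()
    print(f"rank {rank} got a tensor of shape {tuple(t.shape)} on {t.device}")

if __name__ == '__main__':
    ctx = mp.get_context('spawn')     # build the queue from the 'spawn' context
    q = ctx.SimpleQueue()
    t = torch.randn(4, 4)             # placeholder; I move my CUDA outputs to CPU before putting them
    q.put(t)
    mp.spawn(worker, args=(q,), nprocs=1)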
Best

Let me know how it goes …

s

You are not alone. I would like to share the Dataset object among the processes, since I already do some preprocessing there (like collecting metadata about the data, etc.), which takes time.

Did you ever find a solution?
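For what it is worth, this is roughly the pattern I am after - a sketch, where MyDataset stands in for my real dataset with its slow metadata step:

import torch
import torch.multiprocessing as mp
from torch.utils.data import Dataset

class MyDataset(Dataset):
    # Placeholder dataset whose __init__ does the expensive metadata collection
    def __init__(self):
        self.data = torch.randn(100, 10)   # pretend this took a long time to prepare

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def worker(rank, world_size, dataset):
    # each spawned worker receives a pickled copy of the already-built dataset,
    # so the metadata work is not repeated inside every process
    print(f"rank {rank}/{world_size}: dataset has {len(dataset)} samples")

if __name__ == '__main__':
    dataset = MyDataset()          # build once in the main process
    world_size = 2                 # or torch.cuda.device_count()
    mp.spawn(worker, args=(world_size, dataset), nprocs=world_size)

(I realize this passes each worker a copy rather than truly sharing memory, but at least the preprocessing would not be repeated.)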

My actual code is just as simple as yours, and I get the same error:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # create default process group
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8888'
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(rank))
    labels = torch.randn(20, 10).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()

def main():
    # world_size = 2
    world_size = torch.cuda.device_count()
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()
    print('Done\n\a')

The error:

$ python playground/multiprocessing_playground/ddp_hello_world.py
Traceback (most recent call last):
  File "playground/multiprocessing_playground/ddp_hello_world.py", line 42, in <module>
    main()
  File "playground/multiprocessing_playground/ddp_hello_world.py", line 36, in main
    mp.spawn(example,
  File "/home/miranda9/miniconda3/envs/automl-meta-learning/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/miranda9/miniconda3/envs/automl-meta-learning/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/home/miranda9/miniconda3/envs/automl-meta-learning/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 105, in join
    raise Exception(
Exception: process 0 terminated with signal SIGSEGV