How can I get returns from a function in distributed data parallel?

Hi, I’m training a model with DistributedDataParallel (DDP). Following this, I wrote:

import os
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    ...
    model = DDP(model, device_ids=[rank])
    ...
    loss = criterion(y_pred, y_batch)  # PyTorch loss modules take (input, target)
    dist.destroy_process_group()

if __name__ == "__main__":
    ngpus = torch.cuda.device_count()  # one process per visible GPU
    mp.spawn(main_worker, args=(ngpus,), nprocs=ngpus, join=True)

I’d like to get the loss back from main_worker in the parent process.
I tried this, but it did not work for me.

How can I get returns from a function using distributed data parallel?
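For context, a plain `return` cannot work here: `mp.spawn(..., join=True)` waits for the workers and then returns `None`, so whatever the worker function returns is discarded. A minimal sketch showing this (the `worker` function is a hypothetical stand-in for `main_worker`):

```python
import torch.multiprocessing as mp


def worker(rank, world_size):
    # mp.spawn calls this as worker(rank, *args); the return value is
    # thrown away and never reaches the parent process.
    return rank * 10


if __name__ == "__main__":
    result = mp.spawn(worker, args=(2,), nprocs=2, join=True)
    print(result)  # None
```

So the result has to travel through some inter-process channel instead.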

Hi, you can use

  1. torch.multiprocessing.SimpleQueue, so that the child processes can put their results in the queue.

  2. point-to-point communication functions (torch.distributed.send / torch.distributed.recv) to send tensors between the distributed processes.
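Option 1 might look like the sketch below. It is CPU-only, and the per-rank "loss" is a hypothetical placeholder for a real training step; `worker` and `run` are illustrative names. One assumption worth stating: the queue is created from the `"spawn"` context, because `mp.spawn` starts workers with spawn, and on Linux a `SimpleQueue` built from the default `"fork"` context cannot be passed to spawned workers.

```python
import torch.multiprocessing as mp


def worker(rank, world_size, queue):
    # Hypothetical stand-in for training: each rank produces a scalar loss.
    loss = float(rank + 1) / world_size
    # Put a plain Python object (not a tensor with autograd history),
    # tagged with the rank so the parent can restore rank order.
    queue.put((rank, loss))


def run(world_size):
    # Use the same "spawn" context that mp.spawn uses for its workers.
    ctx = mp.get_context("spawn")
    queue = ctx.SimpleQueue()
    mp.spawn(worker, args=(world_size, queue), nprocs=world_size, join=True)
    # All workers have exited; exactly world_size items are waiting.
    results = [queue.get() for _ in range(world_size)]
    return [loss for _, loss in sorted(results)]


if __name__ == "__main__":
    print(run(2))  # [0.5, 1.0]
```

Reading only after `join=True` is fine for a few floats. With large payloads a worker can block on a full pipe while the parent waits for it to exit, so in that case call `mp.spawn(..., join=False)`, drain the queue, and then call `.join()` on the returned context.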

You may want to refer to this thread for more explanation.
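Option 2 could be sketched with `torch.distributed.send` / `torch.distributed.recv`, where every rank sends its loss to rank 0. Assumptions: the CPU `"gloo"` backend so it runs without GPUs (with NCCL the tensors would have to live on the GPU), a free local port picked at startup, and placeholder losses. Point-to-point only moves data between ranks, so rank 0 still hands the collected list to the parent through a queue. `worker`, `run`, and `find_free_port` are hypothetical names.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def find_free_port():
    # Ask the OS for an unused TCP port for the rendezvous.
    with socket.socket() as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


def worker(rank, world_size, port, queue):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    # Assumption: CPU "gloo" backend instead of "nccl".
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Hypothetical per-rank loss; in real training send loss.detach().
    loss = torch.tensor([float(rank + 1)])

    if rank == 0:
        # Rank 0 receives one tensor from each other rank, in rank order.
        losses = [loss.item()]
        for src in range(1, world_size):
            buf = torch.zeros(1)
            dist.recv(buf, src=src)
            losses.append(buf.item())
        # Only rank 0 holds the full list; pass it on to the parent.
        queue.put(losses)
    else:
        dist.send(loss, dst=0)

    # Make sure every transfer has finished before tearing down.
    dist.barrier()
    dist.destroy_process_group()


def run(world_size):
    ctx = mp.get_context("spawn")
    queue = ctx.SimpleQueue()
    port = find_free_port()
    mp.spawn(worker, args=(world_size, port, queue), nprocs=world_size, join=True)
    return queue.get()


if __name__ == "__main__":
    print(run(2))  # [1.0, 2.0]
```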

Thanks! 🙂
In my case, however, I solved it with torch.multiprocessing.Pipe, as below:

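import torch
import torch.multiprocessing as mp

def main_worker(rank, world_size, conn):
    ...
    # Send a plain Python float: a non-leaf tensor that requires grad cannot
    # be pickled across processes, and a shared CUDA tensor is only valid
    # while the sending process is alive.
    conn.send(loss.item())

if __name__ == "__main__":
    ngpus = torch.cuda.device_count()
    parent_conn, child_conn = mp.Pipe()
    mp.spawn(main_worker, args=(ngpus, child_conn), nprocs=ngpus, join=True)
    # All workers have exited by now; drain whatever they sent.
    losses = []
    while parent_conn.poll():
        losses.append(parent_conn.recv())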

Then I can gather the losses from every worker (in the order the messages arrived, which is not necessarily rank order).
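A self-contained version of the Pipe approach might look like this. It is a sketch under stated assumptions: it runs on CPU, the "loss" is a hypothetical tensor built with autograd history to mimic a real one, and `gather_losses` is an illustrative helper. Each worker sends `(rank, value)` so the parent can sort the results back into rank order.

```python
import torch
import torch.multiprocessing as mp


def main_worker(rank, world_size, conn):
    # Hypothetical stand-in for training: a scalar loss that is part of
    # an autograd graph, as a real loss would be.
    w = torch.tensor(float(rank + 1), requires_grad=True)
    loss = w * 2
    # .item() converts it to a float; sending the tensor itself would fail
    # because it is a non-leaf tensor that requires grad.
    conn.send((rank, loss.item()))


def gather_losses(world_size):
    parent_conn, child_conn = mp.Pipe()
    mp.spawn(main_worker, args=(world_size, child_conn), nprocs=world_size, join=True)
    # Messages arrive in whatever order the workers finished.
    received = []
    while parent_conn.poll():
        received.append(parent_conn.recv())
    return [loss for _, loss in sorted(received)]


if __name__ == "__main__":
    print(gather_losses(2))  # [2.0, 4.0]
```

As with the queue, reading only after `join=True` relies on the messages being small enough to fit in the pipe's buffer.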