Hi, I’m working on a model with DistributedDataParallel. Referring to this, I ended up with the following:
import os
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    # one process per GPU; rank indexes into the available devices
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    ...
    model = DDP(model, device_ids=[rank])
    ...
    loss = criterion(y_pred, y_batch)  # nn losses expect (input, target)
    dist.destroy_process_group()

if __name__ == "__main__":
    ngpus = torch.cuda.device_count()  # one worker process per GPU
    mp.spawn(main_worker, args=(ngpus,), nprocs=ngpus, join=True)
And I’d like to get the loss returned from main_worker back in the parent process.
I tried returning it directly (a simplified sketch of my attempt is below), but it does not work for me.
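Concretely, my attempt was along these lines (simplified; loss here stands for whatever the training loop last computed):

import torch
import torch.multiprocessing as mp

def main_worker(rank, world_size):
    ...
    return loss  # hoping to collect this in the parent process

if __name__ == "__main__":
    ngpus = torch.cuda.device_count()
    # mp.spawn(..., join=True) returns None, not the workers' return
    # values, so there is nothing to capture here
    result = mp.spawn(main_worker, args=(ngpus,), nprocs=ngpus, join=True)
    print(result)  # prints None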
How can I get return values from a worker function when using distributed data parallel?