Hi, I’m working on a model with distributed data parallel. Referring to this, I wrote:
```python
import os
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    ...
    model = DDP(model, device_ids=[rank])
    ...
    loss = criterion(y_pred, y_batch)  # loss functions take (input, target)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(main_worker, args=(ngpus,), nprocs=ngpus, join=True)
```
And I’d like to get a return value back from `main_worker`.
I tried this, but it does not work for me.
How can I get return values from a worker function when using distributed data parallel?