Hi there, I am trying to use torch.distributed.gather in a setup with 4 GPUs on 1 node, but I can't make it work. Where am I going wrong?
In practice, I wrap my model in DDP, and after computing the loss on every rank I want to gather it on rank 0.
My code is something like:
import os
import argparse
import warnings
import numpy as np
import torch
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from torch.nn.parallel import DistributedDataParallel as ddp
def init_run():
    parser = argparse.ArgumentParser()
    # here add a lot of args, not of interest
    opt = parser.parse_args()
    return opt

def setup(rank, world_size):
    if 'MASTER_ADDR' not in os.environ:
        os.environ["MASTER_ADDR"] = "localhost"
        if rank == 0:
            warnings.warn("Set Environ Variable 'MASTER_ADDR'='localhost'")
    if 'MASTER_PORT' not in os.environ:
        os.environ["MASTER_PORT"] = "29500"
        if rank == 0:
            warnings.warn("Set Environ Variable 'MASTER_PORT'='29500'")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def set_models(opt):
    setup(opt.rank, opt.world_size)
    model = Model(opt).to(opt.rank)  # Model is my network class (placeholder here)
    model = ddp(model, device_ids=[opt.rank])
    return model

def train(rank, world_size, opt):
    opt.rank = rank
    opt.world_size = world_size
    model = set_models(opt)
    # criterion and optimizer
    criterion = Loss(temperature=opt.temp).to('cuda')
    optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    cudnn.benchmark = True
    # train loop
    for batch_idx in range(opt.num_samples // opt.batch_size):
        # forward model
        features = ...  # forward pass (some lines of code omitted)
        loss_f = criterion(features)
        loss_f.backward()
        optimizer.step()
        # gather between all ranks
        print(f'RANK: {rank} - GATHERING')
        this_value = loss_f.detach()
        if opt.rank == 0:
            collected = [torch.zeros_like(this_value) for _ in range(world_size)]
            dist.gather(gather_list=collected, tensor=this_value, dst=0, group=dist.group.WORLD)
        else:
            dist.gather(tensor=this_value, dst=0, group=dist.group.WORLD)
        print(f'RANK: {rank} - DONE')
    # some other code ...
    dist.destroy_process_group()

if __name__ == '__main__':
    w_size = torch.cuda.device_count()
    print(f'Using {w_size} gpus for training model')
    mp.spawn(train, args=(w_size, init_run()), nprocs=w_size)
After entering the training loop, I get this output:
RANK: 0 - GATHERING
RANK: 2 - GATHERING
RANK: 3 - GATHERING
RANK: 1 - GATHERING
and then nothing else: apparently the script freezes at the dist.gather call.
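To isolate the gather step, this is the kind of stripped-down reproducer I have in mind: a standalone sketch (run and the dummy per-rank tensor are placeholders, not my real code) that keeps the same "nccl" backend and the same dist.gather call, and additionally pins each process to its own GPU with torch.cuda.set_device(rank), which my real script does not do:
# standalone sketch isolating the gather step; dummy tensor stands in for the detached loss
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # one distinct GPU per process
    # dummy scalar standing in for loss_f.detach()
    this_value = torch.full((1,), float(rank), device=f'cuda:{rank}')
    print(f'RANK: {rank} - GATHERING')
    if rank == 0:
        collected = [torch.zeros_like(this_value) for _ in range(world_size)]
        dist.gather(tensor=this_value, gather_list=collected, dst=0)
        print(f'RANK 0 - collected: {[t.item() for t in collected]}')
    else:
        dist.gather(tensor=this_value, dst=0)
    print(f'RANK: {rank} - DONE')
    dist.destroy_process_group()

if __name__ == '__main__':
    w_size = torch.cuda.device_count()
    mp.spawn(run, args=(w_size,), nprocs=w_size)
On 4 GPUs I would expect every rank to print GATHERING and DONE, and rank 0 to print the four gathered values, but my real script never gets past GATHERING.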