Can anybody clarify why I keep getting a CUDA error when using `torch.nn.parallel.DistributedDataParallel`?
MWE:

```python
if torch.cuda.device_count() > 1:
    torch.distributed.init_process_group(backend='nccl')
    model = torch.nn.parallel.DistributedDataParallel(model).to(device)
```
Executing the script with:

```
python -m torch.distributed.launch main.py
```
Am I using it incorrectly, or am I missing something else?
You have to run `DistributedDataParallel` with as many processes as the value of `world_size`. If you specify `rank=1` in a single process, it will hang waiting for the process with `rank=0` to start.
Check out the launch utility for easy launching of multiple processes. You can use it like this:
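For example (a sketch, assuming one process per GPU and a `model` you have already built; `torch.distributed.launch` passes each spawned process a `--local_rank` argument that your script must accept):

```
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main.py
```

and in `main.py`:

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

# Each launched process drives exactly one GPU.
torch.cuda.set_device(args.local_rank)

# The launcher sets MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and RANK in the
# environment, so init_process_group can pick them up via env://.
torch.distributed.init_process_group(backend='nccl', init_method='env://')

# Move the model to this process's GPU *before* wrapping it in DDP,
# and tell DDP which device this replica lives on.
model = model.to(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[args.local_rank], output_device=args.local_rank
)
```

Note the order of operations: in your MWE you wrap the model first and call `.to(device)` afterwards, whereas the model should be on its GPU before `DistributedDataParallel` wraps it.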