Hi guys,
I was trying to wrap my model with DistributedDataParallel. My model is split into 2 parts, and each part runs on its own GPU. I followed "Combine DDP with Model Parallelism" in the official tutorial, but after that I ran into RuntimeError: Socket Timeout.
My codebase is basically like this:
```python
import os
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# fire tasks on SLURM cluster...
os.environ["MASTER_PORT"] = str(port)
os.environ["MASTER_ADDR"] = str(master_ip)
os.environ["WORLD_SIZE"] = str(n_tasks)
os.environ["RANK"] = str(proc_id)
dist.init_process_group(backend=dist.Backend.NCCL, timeout=timedelta(seconds=30))
# ...

class MyModel(nn.Module):
    def __init__(self, ..., device0, device1):
        # ...
        # each half of the model lives on its own GPU
        self.part_1.to(device0)
        self.part_2.to(device1)

# task 0 gets GPUs {0, 1}, task 1 gets GPUs {2, 3}, ...
d0 = torch.device(f"cuda:{rank * 2}")
d1 = torch.device(f"cuda:{rank * 2 + 1}")
model = MyModel(..., d0, d1)
# not all parameters are used in each iteration
ddp_model = DistributedDataParallel(model, find_unused_parameters=True)
# ...
```
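For reference, here is a minimal pure-Python sketch of the rank and GPU bookkeeping in the snippet above. The helper names (dist_env_from_slurm, gpus_for_task, check_assignment) are hypothetical, not part of PyTorch or of the original code. It assumes the tasks are launched with srun, which sets SLURM_PROCID (global task rank), SLURM_NTASKS (total number of tasks) and SLURM_LOCALID (task index on its node).

```python
import os


def dist_env_from_slurm(master_ip, port, environ=None):
    """Variables read by init_process_group's default "env://" init method.

    Assumption: tasks were started with srun, which sets SLURM_PROCID
    (global rank) and SLURM_NTASKS (total number of tasks) in each task.
    """
    env = os.environ if environ is None else environ
    return {
        "MASTER_ADDR": str(master_ip),
        "MASTER_PORT": str(port),
        "WORLD_SIZE": env["SLURM_NTASKS"],
        "RANK": env["SLURM_PROCID"],
    }


def gpus_for_task(task_index, gpus_per_task=2):
    """GPU indices for one task: 0 -> [0, 1], 1 -> [2, 3], ...

    Same mapping as cuda:{rank * 2} / cuda:{rank * 2 + 1}. CUDA numbers GPUs
    per node, so on a multi-node job pass the node-local task index
    (SLURM_LOCALID) rather than the global rank.
    """
    first = task_index * gpus_per_task
    return list(range(first, first + gpus_per_task))


def check_assignment(tasks_per_node, gpus_on_node, gpus_per_task=2):
    """Raise ValueError if a GPU is shared or does not exist; return {gpu: task}."""
    owner = {}
    for task in range(tasks_per_node):
        for gpu in gpus_for_task(task, gpus_per_task):
            if gpu >= gpus_on_node:
                raise ValueError(f"task {task} wants cuda:{gpu}, node has {gpus_on_node} GPUs")
            if gpu in owner:
                raise ValueError(f"cuda:{gpu} assigned to tasks {owner[gpu]} and {task}")
            owner[gpu] = task
    return owner


if __name__ == "__main__":
    fake_env = {"SLURM_PROCID": "1", "SLURM_NTASKS": "4"}  # made-up values
    print(dist_env_from_slurm("10.0.0.1", 29500, fake_env))
    print(check_assignment(tasks_per_node=4, gpus_on_node=8))
```

With 4 tasks per node and 2 GPUs per task, check_assignment(4, 8) passes, while check_assignment(4, 4) raises, because task 2 would ask for cuda:4 on a node that only has cuda:0 to cuda:3.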
Constructing the DDP model did not raise an error right away; however, once the timeout elapsed (30 s in my setting), I got the following error:
```
Traceback (most recent call last):
  File "../tools/train_val_classifier.py", line 332, in <module>
    main()
  File "../tools/train_val_classifier.py", line 103, in main
    model, model_without_ddp = get_ddp_model(model, devices=(fp_device, q_device))
  File ".../quant_prob/utils/distributed.py", line 120, in get_ddp_model
    ddp_model = DistributedDataParallel(model, device_ids=devices, find_unused_parameters=True)
  File "/envs/r0.3.0/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 286, in __init__
    self.broadcast_bucket_size)
  File "/envs/r0.3.0/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 410, in _dist_broadcast_coalesced
    dist._dist_broadcast_coalesced(self.process_group, tensors, buffer_size, False)
RuntimeError: Socket Timeout
```
It seems that this error comes from the DDP implementation. I am definitely sure that I followed the official tutorial, and the GPUs assigned to each task did not overlap. How can I fix this? Thank you so much~