DistributedDataParallel init hangs

The following code hangs at DistributedDataParallel (DDP) initialisation when multiple GPUs are used. This happens both with torchrun and with the multiprocessing approach described in the PyTorch tutorial "Writing Distributed Applications with PyTorch". A single-GPU run works fine.

Example code:

from torchvision.models import resnet18
from torch import nn
import torch
import os

rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
print(f"init process group {rank}")
torch.cuda.set_device(rank)
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
device = torch.device(f"cuda:{rank}")
# torch.distributed.barrier()
model = resnet18().to(device)
print(f"starting {rank}")

new_model = nn.parallel.DistributedDataParallel(model, device_ids=[device])
print(f"done{rank}")

I am running it with this command:

torchrun --nnodes=1 --nproc_per_node=gpu simple.py

I get this output before it hangs; eventually it times out. A single-GPU run takes about 5 seconds, so the size of the model is not the issue.

init process group 1
init process group 0
init process group 3
starting 2
starting 3
starting 1
starting 0

Adding the --rdzv-backend=c10d --rdzv-endpoint=localhost:0 options makes no difference. As there is no error, I don't know how to debug this.
Env:
CUDA 12.2, driver version 565.57.01, 4 x A6000 GPUs, torch 2.1.1+cu121

The same script runs successfully on another machine with a different GPU and driver version but the same CUDA and torch versions.
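
One way to narrow this down might be to take DDP out of the picture. The DDP constructor performs collective communication (it broadcasts the model state from rank 0), so a hang there could come from NCCL communication between the GPUs rather than from DDP itself. The sketch below turns on NCCL and torch.distributed debug logging and runs a single small all_reduce. The helper names apply_debug_env and check_all_reduce and the TRY_NO_P2P switch are made up for this example, and disabling peer-to-peer via NCCL_P2P_DISABLE is only a hypothesis to test on this machine, not a confirmed fix.

```python
import os

# Diagnostic settings; they need to be in the environment before the
# process group is created so that NCCL picks them up.
DEBUG_ENV = {
    "NCCL_DEBUG": "INFO",                 # NCCL logs topology and transport choices
    "TORCH_DISTRIBUTED_DEBUG": "DETAIL",  # extra checks and logging in torch.distributed
}


def apply_debug_env(env=os.environ, disable_p2p=False):
    """Set diagnostic variables without overriding values already present."""
    for key, value in DEBUG_ENV.items():
        env.setdefault(key, value)
    if disable_p2p:
        # Hypothesis to test, not a known fix: avoid direct GPU peer-to-peer
        # transfers so NCCL falls back to another transport.
        env.setdefault("NCCL_P2P_DISABLE", "1")
    return env


def check_all_reduce():
    """Run one small all_reduce; a hang here points at NCCL rather than DDP."""
    # Imported here so the module can be loaded on a machine without torch.
    import torch
    import torch.distributed as dist

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # Rank and world size are read from the environment set by torchrun.
    dist.init_process_group(backend="nccl")
    x = torch.ones(1, device=f"cuda:{local_rank}")
    dist.all_reduce(x)  # default op is SUM across all ranks
    torch.cuda.synchronize()
    print(f"rank {dist.get_rank()}: all_reduce ok, value={x.item()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    # TRY_NO_P2P is a hypothetical switch used only by this script.
    apply_debug_env(disable_p2p=os.environ.get("TRY_NO_P2P") == "1")
    check_all_reduce()
```

Saved under a name such as nccl_check.py (hypothetical), it can be launched the same way as the original script:

torchrun --nnodes=1 --nproc_per_node=gpu nccl_check.py

If this also hangs, the NCCL_DEBUG output should show which transport each pair of GPUs selected before the stall. Rerunning with TRY_NO_P2P=1 set in the shell would then show whether peer-to-peer transfers are involved.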