Pytorch DDP get stuck in getting free port

I try to get a free port in DDP initialization of PyTorch. However, my code get stuck. The following snippet could repeat my description:

def get_open_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(('', 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    port = get_open_port()
    os.environ['MASTER_PORT'] = str(port)   # '12345'

    # Initialize the process group.
    dist.init_process_group('NCCL', rank=rank, world_size=world_size)

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 5)

    def forward(self, x):
        print(f'x device={x.device}')
        return self.net1(x)


def demo_basic(rank, world_size):
    setup(rank, world_size)

    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    logger.info(f'Running DPP on rank={rank}.')

    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)  # optimizer takes DDP model.

    optimizer.zero_grad()
    inputs = torch.randn(20, 10)  # .to(rank)

    print(f'inputs device={inputs.device}')
    outputs = ddp_model(inputs)
    print(f'output device={outputs.device}')

    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()

    optimizer.step()

    cleanup()


def run_demo(demo_func, world_size):
    mp.spawn(
        demo_func,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )

run_demo(demo_basic, 4)

The function get_open_port is supposed to free the port after invocation. My questions are: 1. How does it happen? 2. How to fix it?

The problem here is that each of the processes end up with a different port due to get_open_port. As a result, the processes can’t rendezvous properly. The MASTER_PORT env variable needs to be the same on all processes and you probably need to choose a fixed port for this. The other option is to find a free port on the master, then communicate this port to all processes using some out of band mechanism to ensure all processes use the same port.

For example, if you’re running all processes on a single host (which is what your example code does). You can call get_open_port in the run_demo function and pass the free port to all processes as an argument to the demo_func method.

2 Likes