I am trying to get a free port during DDP initialization in PyTorch. However, my code gets stuck. The following snippet reproduces the problem:
def get_open_port():
    """Return a TCP port number that was free at the time of the call.

    The socket is bound to port 0 so the operating system picks an unused
    port; the socket is closed again before returning, which releases the
    port for the caller to use.

    Returns:
        int: the port number chosen by the operating system.

    Note:
        The port is only known to be free at the moment of the call. Another
        process may grab it before the caller binds it, and two separate
        calls (e.g. from two different processes) generally return two
        different ports.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # SO_REUSEADDR must be set *before* bind(); setting it afterwards
        # has no effect on the bind that already happened.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(('', 0))
        return s.getsockname()[1]
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
port = get_open_port()
os.environ['MASTER_PORT'] = str(port) # '12345'
# Initialize the process group.
dist.init_process_group('NCCL', rank=rank, world_size=world_size)
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 5)
def forward(self, x):
print(f'x device={x.device}')
return self.net1(x)
def demo_basic(rank, world_size):
setup(rank, world_size)
logger = logging.getLogger('train')
logger.setLevel(logging.DEBUG)
logger.info(f'Running DPP on rank={rank}.')
# Create model and move it to GPU.
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) # optimizer takes DDP model.
optimizer.zero_grad()
inputs = torch.randn(20, 10) # .to(rank)
print(f'inputs device={inputs.device}')
outputs = ddp_model(inputs)
print(f'output device={outputs.device}')
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
def run_demo(demo_func, world_size):
mp.spawn(
demo_func,
args=(world_size,),
nprocs=world_size,
join=True
)
run_demo(demo_basic, 4)
The function `get_open_port` is supposed to free the port after invocation. My questions are: 1. Why does the code get stuck? 2. How can I fix it?