When I run RPC across more than one machine, the code gets stuck in `init_rpc`. With `opt.num_servers == 1` it works, but only as long as `opt.num_envs_each_server < 30`. I am using Python 3.7.10 and PyTorch 1.8.1. Is there something wrong with this code?
# server
def main():
    """Parse the command-line options and spawn this server's worker processes.

    The trainer processes (``NUM_TRAINER_PROCESSES`` of them) are always
    spawned here. When ``opt.num_servers == 1`` the environment workers
    (``opt.num_envs_each_server`` of them) are spawned on this machine too.
    Otherwise they are presumably provided by the client script.
    """
    opt = Options().parse()
    nprocs = NUM_TRAINER_PROCESSES
    if opt.num_servers == 1:
        # Single-server mode: host the env workers locally as well.
        nprocs += opt.num_envs_each_server
    mp.spawn(run_worker, args=(opt,), nprocs=nprocs, join=True)
def run_worker(idx, opt):
    """Entry point for one process spawned by ``mp.spawn`` on the server.

    Process ``idx == 0`` is the PPO agent. It hands an environment index
    (0 .. ``opt.num_envs - 1``) to each env process that connects to
    ``(opt.address, opt.env_port)``. It then joins RPC as rank 0 under the
    name ``"ppo_agent"``.

    Every other process is an environment worker. It asks the agent for its
    index, then joins RPC as ``"env_<index>"`` with rank
    ``NUM_TRAINER_PROCESSES + index``.

    Args:
        idx: local process index supplied by ``mp.spawn``.
        opt: parsed options. Reads ``address``, ``port``, ``env_port``,
            ``num_envs`` and ``world_size``, plus an optional ``ifname``
            naming the network interface to use for RPC on multi-host runs.
    """
    os.environ["MASTER_ADDR"] = opt.address
    # os.environ only accepts str values; tolerate an int port option.
    os.environ["MASTER_PORT"] = str(opt.port)
    # Per the PyTorch distributed docs, Gloo and TensorPipe choose their
    # network interface from GLOO_SOCKET_IFNAME / TP_SOCKET_IFNAME. Without
    # these they fall back to resolving the host name. If that resolves to a
    # loopback address, peers on other machines cannot connect and init_rpc
    # hangs. ``opt.ifname`` (optional) pins the interface explicitly.
    ifname = getattr(opt, "ifname", None)
    if ifname:
        os.environ["GLOO_SOCKET_IFNAME"] = ifname
        os.environ["TP_SOCKET_IFNAME"] = ifname

    backend = rpc.BackendType.TENSORPIPE
    # NOTE(review): with a single RPC worker thread, an RPC handler that
    # itself blocks on another RPC can deadlock. Consider the library
    # default (16) if the agent/env protocol ever nests calls.
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=1, rpc_timeout=60
    )

    if idx == 0:
        _assign_env_ranks(opt)
        logger = utils.get_logger("agent")
        logger.info("init rpc for ppo agent")
        rpc.init_rpc(
            "ppo_agent",
            rank=0,
            world_size=opt.world_size,
            backend=backend,
            rpc_backend_options=rpc_backend_options,
        )
    else:
        rank = _request_env_rank_from_agent(opt)
        logger = utils.get_logger("env_{}".format(rank))
        logger.info("init rpc for env {}".format(rank))
        rpc.init_rpc(
            "env_{}".format(rank),
            rank=NUM_TRAINER_PROCESSES + rank,
            world_size=opt.world_size,
            backend=backend,
            rpc_backend_options=rpc_backend_options,
        )
        logger.info("env {} is waiting".format(rank))
    rpc.shutdown()


def _assign_env_ranks(opt):
    """Send indices 0 .. opt.num_envs - 1, in order, to connecting env processes."""
    with socket.socket() as server:
        # Allow an immediate restart without "Address already in use".
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((opt.address, opt.env_port))
        server.listen(opt.num_envs)
        for i in tqdm.trange(opt.num_envs, ascii=True):
            conn, _ = server.accept()
            # Close each accepted connection as soon as its index is sent;
            # sendall guarantees the whole payload is written.
            with conn:
                conn.sendall(str(i).encode("utf-8"))
            # NOTE(review): this delays the agent's own init_rpc by roughly
            # opt.num_envs seconds. It is presumably meant to stagger the env
            # processes' init_rpc calls — confirm it is still needed.
            time.sleep(1)


def _request_env_rank_from_agent(opt):
    """Poll the agent once per second until it replies with an env index; return it."""
    while True:
        try:
            with socket.socket() as conn:
                conn.connect((opt.address, opt.env_port))
                reply = conn.recv(1024)
            return int(reply.decode())
        except (OSError, ValueError):
            # Agent not listening yet, connection reset, or an empty or
            # unparsable reply: retry. Other errors now propagate instead of
            # being swallowed.
            time.sleep(1)
# client
def main():
    """Parse the command-line options and spawn this client's env worker processes.

    One process is spawned per ``opt.num_envs_each_server``. ``mp.spawn``
    is called with ``join=True``, so this call blocks until all of them exit.
    """
    opt = Options().parse()
    worker_count = opt.num_envs_each_server
    mp.spawn(run_worker, args=(opt,), nprocs=worker_count, join=True)
def run_worker(idx, opt):
    """Entry point for one environment process spawned on a client machine.

    Asks the PPO agent at ``(opt.address, opt.env_port)`` for an environment
    index, retrying once a second until it answers. It then joins RPC as
    ``"env_<index>"`` with rank ``NUM_TRAINER_PROCESSES + index`` and finally
    calls ``rpc.shutdown()``.

    Args:
        idx: local process index supplied by ``mp.spawn``; used only to name
            the logger.
        opt: parsed options. Reads ``address``, ``port``, ``env_port`` and
            ``world_size``, plus an optional ``ifname`` naming the network
            interface to use for RPC on multi-host runs.
    """
    os.environ["MASTER_ADDR"] = opt.address
    # os.environ only accepts str values; tolerate an int port option.
    os.environ["MASTER_PORT"] = str(opt.port)
    # Per the PyTorch distributed docs, Gloo and TensorPipe choose their
    # network interface from GLOO_SOCKET_IFNAME / TP_SOCKET_IFNAME. Without
    # these they fall back to resolving the host name. If that resolves to a
    # loopback address, the server cannot reach this host and init_rpc hangs.
    # ``opt.ifname`` (optional) pins the interface explicitly.
    ifname = getattr(opt, "ifname", None)
    if ifname:
        os.environ["GLOO_SOCKET_IFNAME"] = ifname
        os.environ["TP_SOCKET_IFNAME"] = ifname

    backend = rpc.BackendType.TENSORPIPE
    # NOTE(review): with a single RPC worker thread, an RPC handler that
    # itself blocks on another RPC can deadlock. Consider the library
    # default (16) if the agent/env protocol ever nests calls.
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=1, rpc_timeout=60
    )
    logger = utils.get_logger("{}".format(idx))

    rank = _fetch_env_rank(opt, logger)
    logger.info("init rpc for env {}".format(rank))
    rpc.init_rpc(
        "env_{}".format(rank),
        rank=NUM_TRAINER_PROCESSES + rank,
        world_size=opt.world_size,
        backend=backend,
        rpc_backend_options=rpc_backend_options,
    )
    logger.info("env {} is waiting".format(rank))
    rpc.shutdown()


def _fetch_env_rank(opt, logger):
    """Poll the agent once per second until it replies with an env index; return it."""
    while True:
        try:
            # The context manager closes the socket on every attempt,
            # including failed connects, instead of leaving it to the GC.
            with socket.socket() as conn:
                conn.connect((opt.address, opt.env_port))
                reply = conn.recv(1024)
            return int(reply.decode())
        except (OSError, ValueError) as e:
            # Agent not listening yet, connection reset, or an empty or
            # unparsable reply: log and retry. Other errors (e.g. a bad
            # option type) now propagate instead of looping forever.
            logger.info(e)
            time.sleep(1)