I plan to use 100 servers to run thousands of environments, but I don't know how to assign a unique rank to each environment process without using raw sockets. Is there an intended usage that I'm not aware of?
In fact, I once ran the same code on PyTorch 1.6 and it started up without problems, but I could only run up to about 1000 environments because of the port limit (see When I use 1024 nodes in rpc, I meet RuntimeError "listen: Address already in use" - #6 by yueyilia). I don't know why the error occurs now.
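For example, what I have in mind is something like the sketch below, which uses torch.distributed.TCPStore as a shared atomic counter instead of my own socket handshake. This is just a rough, untested idea (the function name, store port, and key name are placeholders), and I don't know whether this is the intended usage or whether it scales to thousands of workers:

import torch.distributed as dist

def assign_rank(master_addr, store_port, world_size, is_trainer):
    # hypothetical: the trainer hosts the store, every env process connects as a client
    store = dist.TCPStore(master_addr, store_port, world_size, is_trainer)
    if is_trainer:
        # keep a reference to the store on the trainer so it stays alive
        return 0, store
    # add() acts as an atomic counter, so every env should get a distinct value;
    # env ranks then start at 1 (rank 0 is the trainer)
    return int(store.add("env_rank_counter", 1)), store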
I still haven't found a solution; the following is a minimal repro.
import os
import time
import tqdm
import socket
import argparse
import logging
import logging.handlers
import torch
import torch.multiprocessing as mp
import torch.distributed
import torch.distributed.rpc as rpc
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, default="trainer", help="'trainer' on the master node, anything else on env servers")
parser.add_argument("--address", type=str, default="10.90.224.127", help="trainer / master address")
parser.add_argument("--port", type=str, default="10088", help="MASTER_PORT used by rpc")
parser.add_argument("--rank_port", type=int, default=10099, help="port of the plain socket used to hand out ranks")
parser.add_argument("--num_envs_each_server", type=int, default=16, help="env processes spawned per server")
parser.add_argument("--num_servers", type=int, default=2, help="number of env servers")
parser.add_argument("--num_envs", type=int, default=32, help="total envs = num_servers * num_envs_each_server")
parser.add_argument("--world_size", type=int, default=33, help="1 trainer + num_envs")
opt = parser.parse_args()
def main():
    if opt.mode == "trainer":
        mp.spawn(run_worker, args=(opt,), nprocs=1, join=True)
    else:
        mp.spawn(run_worker, args=(opt,), nprocs=opt.num_envs_each_server, join=True)
def run_worker(idx, opt):
    os.environ["MASTER_ADDR"] = opt.address
    os.environ["MASTER_PORT"] = opt.port
    backend = rpc.BackendType.TENSORPIPE
    rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=1, rpc_timeout=60
    )
    if opt.mode == "trainer":
        # hand out a unique rank to every env over a plain TCP socket
        s = socket.socket()
        s.bind((opt.address, opt.rank_port))
        s.listen(opt.num_envs)
        for i in tqdm.trange(opt.num_envs, ascii=True):
            c, addr = s.accept()
            c.send(str(i).encode("utf-8"))
            c.close()
            time.sleep(1)
        s.close()
        logger = get_logger("agent")
        logger.info("init rpc for ppo agent")
        rpc.init_rpc(
            "ppo_agent",
            rank=0,
            world_size=opt.world_size,
            backend=backend,
            rpc_backend_options=rpc_backend_options,
        )
        # torch.distributed.init_process_group(
        #     backend="gloo", rank=0, world_size=opt.world_size,
        # )
        logger.info("end")
    else:
        # keep retrying until the trainer's rank server is up, then receive our rank
        while 1:
            try:
                s = socket.socket()
                s.connect((opt.address, opt.rank_port))
                rank = int(s.recv(1024).decode())
                s.close()
                break
            except Exception:
                time.sleep(1)
        logger = get_logger("env_{}".format(rank))
        logger.info("init rpc for env {}".format(rank))
        rpc.init_rpc(
            "env_{}".format(rank),
            rank=1 + rank,
            world_size=opt.world_size,
            backend=backend,
            rpc_backend_options=rpc_backend_options,
        )
        # torch.distributed.init_process_group(
        #     backend="gloo", rank=1 + rank, world_size=opt.world_size,
        # )
        logger.info("env {} is waiting".format(rank))
    rpc.shutdown()
def get_logger(name="", level=logging.INFO, stream=True, file=None):
    try:
        # work around absl hijacking the root logger (e.g. when TensorFlow is installed)
        import absl.logging
        logging.root.removeHandler(absl.logging._absl_handler)
        absl.logging._warn_preinit_stderr = False
    except Exception as e:
        print("failed to fix absl logging bug", e)
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if stream:
        stream_handler = logging.StreamHandler()
        stream_formatter = logging.Formatter("%(asctime)s - %(message)s")
        stream_handler.setFormatter(stream_formatter)
        logger.addHandler(stream_handler)
    if file:
        path = os.path.join(file, name + ".log")
        file_handler = logging.handlers.RotatingFileHandler(
            path, "a", 100 * 1024 * 1024, 1, encoding="utf-8"
        )
        file_formatter = logging.Formatter(
            "%(asctime)s %(levelname)s [%(filename)s: %(lineno)d] [%(processName)s: %(process)d] - %(message)s"
        )
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)
    return logger
if __name__ == "__main__":
    main()
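For completeness, this is roughly how I launch it (repro.py is just what I call the file; with the defaults above, 2 servers * 16 envs = 32 envs, so world_size = 33):

# on the trainer machine (10.90.224.127)
python repro.py --mode trainer --address 10.90.224.127 --num_envs 32 --world_size 33

# on each of the 2 env servers (any --mode other than "trainer" takes the env branch)
python repro.py --mode env --address 10.90.224.127 --num_envs_each_server 16 --num_envs 32 --world_size 33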