Hi,

I'm trying to use the `torch.profiler` tool, but I am running into issues when using it concurrently with `torch.distributed`. My training script is organized like this:
```python
import argparse
import os

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def train(local_world_size, local_rank, args):
    # load data
    # set up model
    model = DDP(model)
    ...
    # train
    # evaluate


def main(local_world_size, local_rank, args):
    # These are the parameters used to initialize the process group
    home = os.path.expanduser("~")
    init_file = f"{home}/shared_init_file"
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(
        backend="nccl",
        rank=local_rank,
        world_size=local_world_size,
        init_method=f"file://{init_file}",
    )
    print(
        f"[{os.getpid()}] world_size = {dist.get_world_size()}, "
        + f"rank = {dist.get_rank()}, backend={dist.get_backend()}"
    )

    # train model
    train(args.local_world_size, args.local_rank, args)

    # Tear down the process group
    dist.barrier()  # first synchronize devices
    dist.destroy_process_group()  # then destroy

    # remove shared init file
    if os.path.exists(init_file):
        os.remove(init_file)


if __name__ == "__main__":
    # get CLI arguments
    parser = argparse.ArgumentParser()
    ...
    ...
    args = parser.parse_args()
    main(args.local_world_size, args.local_rank, args)
```
Irrespective of whether I put the profiler in `main()` or `train()`, the script hangs at the `dist.init_process_group` step of `main()`. The script runs correctly when I remove all lines associated with the profiler. What is the correct way to use the profiler together with `torch.distributed`?
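
For reference, the way I have been placing the profiler looks roughly like the sketch below; the model, data loader, optimizer, and trace output directory here are placeholders rather than my actual setup:

```python
# Rough sketch of how I am wrapping the training loop with torch.profiler.
# Names such as profiled_training_loop and ./profiler_logs are placeholders.
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler


def profiled_training_loop(model, loader, optimizer, device):
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=tensorboard_trace_handler("./profiler_logs"),  # placeholder path
    ) as prof:
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
            loss.backward()
            optimizer.step()
            prof.step()  # advance the profiler schedule each iteration
```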
Thanks.
Kristian