PyTorch Profiler with DDP

Hi

I’m trying to use the torch.profiler tool, but I’m running into issues when using it together with torch.distributed. My training script is organized like this:

```python

import argparse
import os

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def train(local_world_size, local_rank, args):
    # load data
    # set up model
    model = DDP(model)
    ...
    # train
    # evaluate

def main(local_world_size, local_rank, args):
    # These are the parameters used to initialize the process group
    home = os.path.expanduser("~")
    init_file = f"{home}/shared_init_file"

    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    }

    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")

    dist.init_process_group(backend="nccl",
                            rank=local_rank,
                            world_size=local_world_size,
                            init_method=f'file://{init_file}')

    print(
        f"[{os.getpid()}] world_size = {dist.get_world_size()}, "
        + f"rank = {dist.get_rank()}, backend={dist.get_backend()}"
    )

    # train model
    train(local_world_size, local_rank, args)

    # Tear down the process group
    dist.barrier() # first synchronize devices
    dist.destroy_process_group() # then destroy

    # remove the shared init file
    if os.path.exists(init_file):
        os.remove(init_file)


if __name__ == "__main__":
   # get CLI arguments
   parser = argparse.ArgumentParser()
   ...
   ...
   args = parser.parse_args()
   main(args.local_world_size, args.local_rank, args)

Regardless of whether I put the profiler in main() or train(), the script hangs at the dist.init_process_group step in main(). The script runs correctly when I remove all lines associated with the profiler. What is the correct way to use the profiler together with torch.distributed?

Thanks.

Kristian

Does this mean it hangs at init_process_group even before the profiler context has been entered? Which PyTorch version are you using?

For sample code, this worked for me: https://github.com/mrshenli/ptd_benchmark/blob/62f8c89bab7d61607dc101abfb37480e27ee4616/trainer.py#L375-L384
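
For reference, here is a minimal sketch of how torch.profiler is typically scoped inside the per-rank training loop, entered only after dist.init_process_group() has returned. The names run_profiled_epoch, train_step, loader, and the trace directory are placeholders for illustration, not code from the script above:

```python
import torch.distributed as dist
from torch.profiler import (ProfilerActivity, profile, schedule,
                            tensorboard_trace_handler)


def run_profiled_epoch(model, loader, train_step):
    # Enter the profiler only after the process group has been initialized,
    # wrapping the per-rank training loop.
    rank = dist.get_rank()
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
        on_trace_ready=tensorboard_trace_handler(f"./log/rank_{rank}"),
    ) as prof:
        for batch in loader:
            train_step(model, batch)  # forward / backward / optimizer step
            prof.step()               # advance the profiler schedule each iteration
```

Writing one trace directory per rank keeps the ranks from contending for the same output file.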