What is the best practice to send/recv multiple tensors across DDP ranks?

Hi, I recently needed to do parameter synchronization on a DDP-wrapped model (with Megatron model parallelism), and found that I need to send/recv multiple parameter tensors between two arbitrary ranks.

As mentioned in #173542, I attempted to implement this with batch_isend_irecv: for each parameter a remote rank needs from the local rank, I issue a separate dist.P2POp send/recv.
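
For reference, here is roughly what that looks like, as a minimal sketch. params_to_send / params_to_recv and per_param_p2p are hypothetical names for illustration, not my actual code: each maps a peer rank to a list of CUDA tensors, with the recv tensors pre-allocated to the expected shapes.

import torch
import torch.distributed as dist

def per_param_p2p(params_to_send, params_to_recv):
    # One dist.P2POp per parameter tensor.
    ops = []
    for peer, tensors in params_to_send.items():
        for t in tensors:
            ops.append(dist.P2POp(dist.isend, t, peer=peer))
    for peer, bufs in params_to_recv.items():
        for buf in bufs:  # pre-allocated receive buffers
            ops.append(dist.P2POp(dist.irecv, buf, peer=peer))
    # Post all sends/recvs as one batch and wait for completion.
    for req in dist.batch_isend_irecv(ops):
        req.wait()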

The implementation seems to work fine on a single node with 8 GPUs, but when I scale it to multiple nodes, the program hangs. I eventually worked around the issue by concatenating the parameter tensors into one buffer per peer, so that for each pair rank i → rank j there is at most one send op and one recv op. (Though I don't understand why this strategy works.)
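
For completeness, this is roughly the workaround, again as a sketch with the same hypothetical params_to_send / params_to_recv mapping; both sides have to agree on the packing order of the tensors for each peer.

import torch
import torch.distributed as dist

def merged_p2p(params_to_send, params_to_recv):
    # At most one send op and one recv op per peer.
    ops, recv_bufs = [], {}
    for peer, tensors in params_to_send.items():
        # Pack everything destined for this peer into one flat buffer (this is the extra memory).
        ops.append(dist.P2POp(dist.isend,
                              torch.cat([t.reshape(-1) for t in tensors]),
                              peer=peer))
    for peer, tensors in params_to_recv.items():
        numel = sum(t.numel() for t in tensors)
        recv_bufs[peer] = torch.empty(numel, device=tensors[0].device, dtype=tensors[0].dtype)
        ops.append(dist.P2POp(dist.irecv, recv_bufs[peer], peer=peer))

    for req in dist.batch_isend_irecv(ops):
        req.wait()

    # Unpack the flat buffers back into the individual parameter tensors.
    for peer, tensors in params_to_recv.items():
        offset = 0
        for t in tensors:
            t.copy_(recv_bufs[peer][offset:offset + t.numel()].view_as(t))
            offset += t.numel()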

However, this method requires a lot of extra GPU memory and may cause OOM.

So my question is: what is the best strategy to send/recv multiple tensors (of different shapes) across N ranks?

P.S. Here is sample code to reproduce the hang (use --test-type random or --test-type sorted).
Environment: NGC 25.02

import os
import torch
import time
import random
from functools import partial
import logging
from argparse import ArgumentParser
from collections import defaultdict
from torch import distributed as dist

def add_args(parser):
    group = parser.add_argument_group(title='Distributed CKPT Convertor')
    group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')))
    group.add_argument('--test-type', type=str, default='random', choices=['random', 'sorted', 'merged'])
    return parser

def random_send_recv(
    args, T: torch.Tensor
):
    """
        Given [world_size, world_size] matrix T
        generate T[i][j] send ops from rank i to j, and T[j][i] recv ops from j to i
        the payload of each op is an integer of (i + 1) * (j + 1)
        run these ops in a batch_isend_irecv
    """
    send_ops = []
    for j in range(args.world_size):
        for _ in range(T[args.rank][j]):
            send_ops.append(dist.P2POp(
                dist.isend,
                torch.ones(1, device='cuda') * (args.rank + 1) * (j + 1), 
                peer=j
            ))

    recv_ops = []
    recv_datas = defaultdict(list)
    for j in range(args.world_size):
        for _ in range(T[j][args.rank]):
            recv_datas[j].append(torch.empty(1, device='cuda'))
            recv_ops.append(dist.P2POp(
                dist.irecv,
                recv_datas[j][-1], 
                peer=j
            ))
    # Shuffle so the order of ops differs across ranks.
    random.shuffle(send_ops)
    random.shuffle(recv_ops)

    return recv_datas, send_ops + recv_ops


def sorted_send_recv(
    args, T: torch.Tensor
):
    """
        Given [world_size, world_size] matrix T
        generate T[i][j] send ops from rank i to j, and T[j][i] recv ops from j to i
        the payload of each op is an integer of (i + 1) * (j + 1)
        run these ops in a batch_isend_irecv
    """
    send_ops = []
    for j in range(args.world_size):
        for _ in range(T[args.rank][j]):
            send_ops.append(dist.P2POp(
                dist.isend,
                torch.ones(1, device='cuda') * (args.rank + 1) * (j + 1), 
                peer=j
            ))

    recv_ops = []
    recv_datas = defaultdict(list)
    for j in range(args.world_size):
        for _ in range(T[j][args.rank]):
            recv_datas[j].append(torch.empty(1, device='cuda'))
            recv_ops.append(dist.P2POp(
                dist.irecv,
                recv_datas[j][-1], 
                peer=j
            ))
    # NOTE: ops are kept in ascending peer order (no shuffling)
    return recv_datas, send_ops + recv_ops

def merged_send_recv(
    args, T: torch.Tensor
):
    """
        Given a [world_size, world_size] matrix T,
        generate one send op from rank i to rank j and one recv op from rank j to rank i.
        The payload of each op is a tensor of size T[i][j] filled with (i + 1) * (j + 1).
        The ops are run in one batch_isend_irecv.
    """
    send_ops = []
    for j in range(args.world_size):
        send_ops.append(dist.P2POp(
            dist.isend,
            torch.ones(T[args.rank][j], device='cuda') * (args.rank + 1) * (j + 1), 
            peer=j
        ))

    recv_ops = []
    recv_datas = defaultdict(list)
    for j in range(args.world_size):
        recv_datas[j].append(torch.empty(T[j][args.rank], device='cuda'))
        recv_ops.append(dist.P2POp(
            dist.irecv,
            recv_datas[j][-1], 
            peer=j
        ))
    # NOTE: ops are kept in ascending peer order
    return recv_datas, send_ops + recv_ops

if __name__ == '__main__':
    dist.init_process_group(
        backend='nccl'
    )
    parser = ArgumentParser()
    parser = add_args(parser)
    args = parser.parse_args()
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    torch.cuda.set_device(args.local_rank)
    torch.manual_seed(42)
    if args.rank == 0:
        if args.test_type == 'random':
            print("random send recv")
        elif args.test_type == 'sorted':
            print("sorted send recv")
        elif args.test_type == 'merged':
            print("merged send recv")
    for i in range(2):
        # T is identical on every rank because every rank uses the same manual seed.
        T = (1 + torch.randperm(args.world_size ** 2)).reshape(args.world_size, -1).cuda()
        if args.test_type == 'sorted':
            datas, ops = sorted_send_recv(args, T)
        elif args.test_type == 'merged':
            datas, ops = merged_send_recv(args, T)
        else:
            datas, ops = random_send_recv(args, T)
        reqs = dist.batch_isend_irecv(ops)
        print(f'Iter {i} RANK {args.rank} commit requests')
        for req in reqs:
            req.wait()
        print(f"Iter {i} RANK {args.rank} Work finish")
        for src_rank, data_list in datas.items():
            if (torch.stack(data_list) != (args.rank + 1) * (src_rank + 1)).any():
                print(f"Iter {i} RANK {args.rank}: rank {src_rank} -> rank {args.rank} corrupted {data_list}")
        print(f"Iter {i} RANK {args.rank} joined")
        torch.cuda.synchronize()
        dist.barrier()

Launch command (single node, 8 GPUs):

export TORCH_NCCL_BLOCKING_WAIT=1
torchrun --nproc_per_node 8 --nnodes 1 --master_addr localhost --master_port 6000 test.py --test-type random