What's the performance difference between isend/irecv and batch_isend_irecv

Hi dear developers,

Recently, I read the context-parallel attention code in TransformerEngine and found that it uses a ring P2P communication pattern to exchange data. The code below is the core of the communication, which sends and receives one tensor per step:

    send_recv_ops = []

    if batch_p2p_comm:
        if rank % 2 == 0:
            send_op = torch.distributed.P2POp(torch.distributed.isend,
                                              send_tensor,
                                              send_dst,
                                              cp_group)
            recv_op = torch.distributed.P2POp(torch.distributed.irecv,
                                              recv_tensor,
                                              recv_src,
                                              cp_group)
            send_recv_ops.append(send_op)
            send_recv_ops.append(recv_op)
        else:
            recv_op = torch.distributed.P2POp(torch.distributed.irecv,
                                              recv_tensor,
                                              recv_src,
                                              cp_group)
            send_op = torch.distributed.P2POp(torch.distributed.isend,
                                              send_tensor,
                                              send_dst,
                                              cp_group)
            send_recv_ops.append(recv_op)
            send_recv_ops.append(send_op)
        send_recv_reqs = torch.distributed.batch_isend_irecv(send_recv_ops)
    else:
        if rank % 2 == 0:
            send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
            recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
            send_recv_ops.append(send_op)
            send_recv_ops.append(recv_op)
        else:
            recv_op = torch.distributed.irecv(recv_tensor, recv_src, cp_group)
            send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
            send_recv_ops.append(recv_op)
            send_recv_ops.append(send_op)
        send_recv_reqs = send_recv_ops

    return send_recv_reqs
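For context, `send_dst` and `recv_src` are the ring neighbors of the current rank (TransformerEngine derives them from the `cp_group` ranks). Here is a minimal sketch of how they could be computed, together with the even/odd posting order the code above uses so that every posted send has a matching recv already posted on the peer; `ring_p2p_order` is a hypothetical helper for illustration, not TE's actual API:

```python
def ring_p2p_order(rank, world_size):
    """Compute ring neighbors and the op posting order for one rank.

    Each rank sends to the next rank in the ring and receives from
    the previous one. Even ranks post send first, odd ranks post recv
    first, so send/recv pairs line up across neighboring ranks.
    """
    send_dst = (rank + 1) % world_size
    recv_src = (rank - 1) % world_size
    order = ("send", "recv") if rank % 2 == 0 else ("recv", "send")
    return send_dst, recv_src, order
```

With 8 ranks, for example, rank 0 sends to 1 and receives from 7, posting the send first, while rank 3 sends to 4 and receives from 2, posting the recv first.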

I noticed a flag called batch_p2p_comm (default False) that switches from individual isend()/irecv() calls to batch_isend_irecv(). I enabled this flag to see whether it would accelerate the communication, but after profiling I found that the communication time with individual isend()/irecv() was nearly identical to that with batch_isend_irecv().

Given these results, I am curious about the differences between batch_isend_irecv() and isend()/irecv(), since both support asynchronous communication. Can batch_isend_irecv() improve communication speed?

Any help would be appreciated, thank you very much.

I ran a related experiment a while back, probably on H100 hardware (8 GPUs connected via NVLink):

This isn’t fully conclusive: I didn’t test different NCCL versions, different GPU types, or, importantly, larger clusters (over Ethernet or InfiniBand).


I guess this could be related to NCCL's grouped calls (ncclGroupStart/ncclGroupEnd)…?