How can I gather tensors from specific ranks

I want to gather tensors from specific ranks in each rank (For example, I want gather ranks=[0,1] in rank0&rank1, and gather ranks=[2,3] in rank2&3). I implement by initial new group:

import os
import random
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist

from torch.multiprocessing import Process

from absl import flags
from absl import app


flags.DEFINE_integer('nodes_num', 1, 'machine num')
flags.DEFINE_integer('ngpu', 4, 'ngpu per node')
flags.DEFINE_integer('world_size', 4, 'FLAGS.nodes_num*FLAGS.ngpu')
flags.DEFINE_integer('node_rank', 0, 'rank of machine, 0 to nodes_num-1')
flags.DEFINE_integer('rank', 0, 'rank of total threads, 0 to FLAGS.world_size-1, will be re-compute in main_worker func')

def group_gather(tensor, rank, ngpu_per_node):
    #ranks = [0,1]
    if rank == 0 or rank == 1:
        ranks = [0,1]
    if rank == 2 or rank == 3:
        ranks = [2,3]
    print('ranks: ', ranks)
    group = dist.new_group(ranks = ranks)
    tensors_gather = [torch.ones_like(tensor) for _ in range(2)]
    torch.distributed.all_gather(tensors_gather, tensor, group, async_op=False)
    output =, dim=0)
    print('gather out shape: ', output.shape)
    return output

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.fc = nn.Linear(3,2)

    def forward(self, x, rank, ngpu_per_node):
        x_gather = group_gather(x, rank, ngpu_per_node)
        out = self.fc(x_gather)
        return out 

def main(argv):
    del argv
    os.environ['MASTER_ADDR'] = ''
    os.environ['MASTER_PORT'] = str(random.randint(1,100000))
    mp.spawn(main_worker, nprocs=FLAGS.ngpu, args=())

def main_worker(gpu_rank):
    FLAGS._parse_args(FLAGS.read_flags_from_files(['--flagfile=./tmp.cfg']), True)
    FLAGS.rank = FLAGS.node_rank * FLAGS.ngpu + gpu_rank # rank among FLAGS.world_size
    model = ToyModel()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[gpu_rank])

    x = torch.randn(4,3).cuda()
    model(x, FLAGS.rank, FLAGS.ngpu)

if __name__ == '__main__':

In group_gather(…), I init new group according to thread’s rank.

But this scripts can not always work well, It may crash in some times, and raise error:

Traceback (most recent call last):
  File "/root/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/", line 19, in _wrap
    fn(i, *args)
  File "/root/test_distcomm/", line 78, in main_worker
    model(x, FLAGS.rank, FLAGS.ngpu)
  File "/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/root/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/", line 447, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/root/test_distcomm/", line 48, in forward
    x_gather = group_gather(x, gpu_rank, ngpu_per_node)
  File "/root/anaconda3/lib/python3.6/site-packages/torch/autograd/", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/root/test_distcomm/", line 35, in group_gather
    torch.distributed.all_gather(tensors_gather, tensor, group, async_op=False)
  File "/root/anaconda3/lib/python3.6/site-packages/torch/distributed/", line 1153, in all_gather
    work = group.allgather([tensor_list], [tensor])
RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:410, unhandled system error, NCCL version 2.4.8

I think the logic in code is correct, and I can not figure out where is wrong.

I run this code with 4 nvidia-t4 gpus with cuda10.1, my pytorch version is 1.4.0.

You can simply run this code with ‘python’ (may need pip install absl-py)

If I set ranks in group_gather func as [0,1] consistently, this code can work well all the time

The new_group API requires all processes to call with the same ranks argument if even if they do not participate in the new group. See the API doc here:

In the code above, the following code breaks the above assumption.

    if rank == 0 or rank == 1:
        ranks = [0,1]
    if rank == 2 or rank == 3:
        ranks = [2,3]
    print('ranks: ', ranks)
    group = dist.new_group(ranks = ranks)

It needs to be modified to the following:

    g1 = dist.new_group(ranks = [0, 1])
    g2 = dist.new_group(ranks = [2, 3])
    # check rank to see if the current process should use g1 or g2

Yes, it works well now! Thanks very much