How many times do I need to call `new_group` if I want to create m partitions of n processes?

For example, I have processes with ranks [0, 1, 2, 3], and I want to create 3 partitions of them:

  • the first is the whole group, [[0, 1, 2, 3]], i.e. all the processes in a single group
  • the second splits the processes into two groups, [[0, 1], [2, 3]]
  • the third is a different split, [[0, 2], [1, 3]]

How should I call new_group?

I thought of two ways:

The first:

import torch.distributed as dist
dist.init_process_group(backend='gloo')
group0 = dist.group.WORLD
# Each rank calls new_group only for the subgroup it belongs to.
if dist.get_rank() in [0, 1]:
    group1 = dist.new_group(ranks=[0, 1])
else:
    group1 = dist.new_group(ranks=[2, 3])
if dist.get_rank() in [0, 2]:
    group2 = dist.new_group(ranks=[0, 2])
else:
    group2 = dist.new_group(ranks=[1, 3])

The second:

import torch.distributed as dist
dist.init_process_group(backend='gloo')
group0 = dist.group.WORLD
# Every rank creates both subgroups, then keeps the one it belongs to.
a, b = dist.new_group(ranks=[0, 1]), dist.new_group(ranks=[2, 3])
if dist.get_rank() in [0, 1]:
    group1 = a
else:
    group1 = b
# Same pattern for the second split.
a, b = dist.new_group(ranks=[0, 2]), dist.new_group(ranks=[1, 3])
if dist.get_rank() in [0, 2]:
    group2 = a
else:
    group2 = b

The second one seems to work, but the code looks ugly. Is this the right way to do it?
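
To make the second approach less repetitive, it could be folded into a small helper. This is only a sketch, assuming what the torch.distributed documentation states about `new_group`: every process in the main group has to call it, even for subgroups it is not a member of, and the calls must happen in the same order on every rank. Under that assumption, each rank makes one call per subgroup, so 4 calls for the two splits above, while `dist.group.WORLD` covers the whole-group partition without any call. The names `build_partition_groups` and `setup_groups` are hypothetical, and `new_group` is passed in as a parameter so the loop can be checked without a running process group.

```python
def build_partition_groups(partitions, rank, new_group):
    """Return, for each partition, the subgroup handle that contains `rank`.

    `new_group` is called once per subgroup, in the same order on every rank,
    including for subgroups this rank is not a member of.
    """
    my_groups = []
    for partition in partitions:
        mine = None  # stays None if `rank` is in no subgroup of this partition
        for ranks in partition:
            handle = new_group(ranks=ranks)  # every rank makes this call
            if rank in ranks:
                mine = handle
        my_groups.append(mine)
    return my_groups


def setup_groups():
    """Hypothetical usage with torch.distributed for the 4-rank example."""
    import torch.distributed as dist  # imported lazily; needs a launched job

    dist.init_process_group(backend="gloo")
    group0 = dist.group.WORLD  # whole-group partition, no new_group call needed
    partitions = [[[0, 1], [2, 3]], [[0, 2], [1, 3]]]
    group1, group2 = build_partition_groups(
        partitions, dist.get_rank(), dist.new_group
    )
    return group0, group1, group2
```

Each rank would call `setup_groups()` once after launch. On rank 2, for instance, `group1` would be the handle for [2, 3] and `group2` the handle for [0, 2].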