For example, I have processes with ranks [0, 1, 2, 3], and I want to create 3 partitions of the processes:
- one is the whole group, [[0, 1, 2, 3]], i.e. all the processes in a single group
- one splits the processes into two groups, [[0, 1], [2, 3]]
- the final one is a different partition, [[0, 2], [1, 3]]

How should I call new_group?

I thought about two ways:
The first:

    import torch.distributed as dist
    dist.init_process_group(backend='gloo')
    group0 = dist.group.WORLD
    if dist.get_rank() in [0, 1]:
        group1 = dist.new_group(ranks=[0, 1])
    else:
        group1 = dist.new_group(ranks=[2, 3])
    if dist.get_rank() in [0, 2]:
        group2 = dist.new_group(ranks=[0, 2])
    else:
        group2 = dist.new_group(ranks=[1, 3])

The second:

    import torch.distributed as dist
    dist.init_process_group(backend='gloo')
    group0 = dist.group.WORLD
    a, b = dist.new_group(ranks=[0, 1]), dist.new_group(ranks=[2, 3])
    if dist.get_rank() in [0, 1]:
        group1 = a
    else:
        group1 = b
    a, b = dist.new_group(ranks=[0, 2]), dist.new_group(ranks=[1, 3])
    if dist.get_rank() in [0, 2]:
        group2 = a
    else:
        group2 = b

The second one seems to work, but the code looks ugly. Is this the right way to do it?
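
For what it's worth, the second approach could probably be wrapped in a small helper so the call sites look cleaner. This is only a sketch: the name `group_for_partition` and its `rank` / `new_group` parameters are hypothetical (they are not part of torch.distributed) and exist mainly so the logic can be exercised without a running process group. It relies on my understanding of the torch.distributed documentation that every process must call `new_group` for every group, in the same order, even for groups it does not belong to.

```python
def group_for_partition(partition, rank=None, new_group=None):
    """Create every group in `partition` on this process and return
    the group that contains `rank` (or None if no group contains it).

    `partition` is a list of rank lists, e.g. [[0, 1], [2, 3]].
    `rank` and `new_group` are hypothetical hooks for testing; when they
    are omitted, dist.get_rank() and dist.new_group are used.
    """
    if rank is None or new_group is None:
        # Imported lazily so the helper can be tried out without torch.
        import torch.distributed as dist
        if rank is None:
            rank = dist.get_rank()
        if new_group is None:
            new_group = dist.new_group

    mine = None
    for ranks in partition:
        # Every process creates every group, in the same order,
        # including the groups it is not a member of.
        group = new_group(ranks=ranks)
        if rank in ranks:
            mine = group
    return mine
```

With that in place, and assuming `dist.init_process_group` has already been called, the setup would reduce to `group1 = group_for_partition([[0, 1], [2, 3]])` and `group2 = group_for_partition([[0, 2], [1, 3]])`, with `group0 = dist.group.WORLD` as before.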