Subprocess groups w/ DeviceMesh Blocking

I am trying to shard a tensor across a subset of the running processes. In my current setup, I run 4 processes on a single node and only want processes 0 and 1 to be involved in the Shard() operation. Initially, without a barrier(), processes 0 and 1 would stall immediately; with the barrier(), processes 2 and 3 stall instead.

I recognize the issue is likely that new_group requires all processes in the main group (i.e. all processes that are part of the distributed job) to enter the call, even if they are not going to be members of the new group. Per this post, Shard Tensor Across Specific Ranks, a solution could be to change the way the DeviceMesh is initialized. However, I find this less than ideal for what I am currently trying to do: each operation (e.g. shard) isn't aware of all the processes that are part of the job, only of the processes that participate in that operation. In other words, I can't just initialize the DeviceMesh as (2, 2) and take a submesh out of it; I need a DeviceMesh that contains only the ranks required for the operation. Is there any workaround for this?
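For reference, here is a minimal sketch of the submesh approach suggested in that post, assuming init_device_mesh and named-dimension slicing from torch.distributed.device_mesh are available; the (2, 2) shape and the "dp"/"tp" dimension names are only illustrative:

from torch.distributed.device_mesh import init_device_mesh

# Every rank in the job must enter this call, even ranks 2 and 3,
# which is exactly the global knowledge I am trying to avoid.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

# Each rank slices out the 1-D submesh it belongs to, e.g. ranks 0 and 1
# end up in one "tp" submesh and ranks 2 and 3 in the other.
tp_mesh = mesh_2d["tp"]

What I am actually running is the direct version below: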

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
# On older PyTorch releases these live under torch.distributed._tensor.
from torch.distributed.tensor import distribute_tensor, Shard

local_rank = int(os.environ.get("LOCAL_RANK", 0))
dist.init_process_group(backend='nccl')
torch.cuda.set_device(local_rank)
global_rank = dist.get_rank()

print(f"[Rank {global_rank}] Local Rank = {local_rank}, Device = cuda:{local_rank}")


# Only ranks 0 and 1 build the mesh and shard; ranks 2 and 3 skip this block.
mesh_ranks = [0, 1]
if global_rank in mesh_ranks:
    device_mesh = DeviceMesh("cuda", mesh_ranks)

    x = torch.tensor([[1, 2], [3, 4]], device=f"cuda:{local_rank}")
    dt = distribute_tensor(x, device_mesh, [Shard(0)])

    print(f"[Rank {global_rank}] Sharded tensor: {dt}")
    print(f"[Rank {global_rank}] Reconstructed: {dt.full_tensor()}")
else:
    print(f"[Rank {global_rank}] Skipping mesh participation.")

dist.barrier()

print(f"[Rank {global_rank}] Lets continue.")



Output:
[Rank 2] Local Rank = 2, Device = cuda:2
[Rank 2] Skipping mesh participation.
[Rank 0] Local Rank = 0, Device = cuda:0
[Rank 1] Local Rank = 1, Device = cuda:1
[Rank 3] Local Rank = 3, Device = cuda:3
[Rank 3] Skipping mesh participation.
[Rank 0] Sharded tensor: DTensor(local_tensor=tensor([[1, 2]], device='cuda:0'), device_mesh=DeviceMesh([0, 1]), placements=(Shard(dim=0),))
[Rank 1] Sharded tensor: DTensor(local_tensor=tensor([[3, 4]], device='cuda:1'), device_mesh=DeviceMesh([0, 1]), placements=(Shard(dim=0),))
[Rank 0] Reconstructed: tensor([[1, 2],
        [3, 4]], device='cuda:0')
[Rank 0] Lets continue.
[Rank 1] Lets continue.

This is indeed needed. Any reason why you cannot do that in the SPMD style?

I have some specific use cases that don't fit the SPMD style well. My program follows a producer-consumer pattern, where some devices are responsible for creating data and other devices use that data for training. In this case I want the producers and consumers to be essentially independent, except for limited communication between them.
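To make that concrete, here is a minimal sketch of how I picture the split with plain process groups, following the collective requirement noted above (every rank enters each new_group call, and only member ranks use the resulting group); the rank assignments and the single broadcast are purely illustrative:

import os

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

producer_ranks = [0, 1]   # illustrative split
consumer_ranks = [2, 3]
bridge_ranks = [1, 2]     # limited producer -> consumer communication path

# All ranks must enter every new_group call, even non-members.
producer_group = dist.new_group(ranks=producer_ranks)
consumer_group = dist.new_group(ranks=consumer_ranks)
bridge_group = dist.new_group(ranks=bridge_ranks)

payload = torch.zeros(4, device=f"cuda:{local_rank}")

if rank in producer_ranks:
    # Producers prepare data among themselves on producer_group ...
    payload += rank + 1
if rank in bridge_ranks:
    # ... and hand it over only on the small bridge group (src is a global rank).
    dist.broadcast(payload, src=1, group=bridge_group)
if rank in consumer_ranks:
    # Consumers would train on the received data using consumer_group.
    pass

dist.barrier()
dist.destroy_process_group()

Even here, though, every rank still has to know about every group up front, which is the part I would like to avoid.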