FSDP/HSDP with `device_mesh`: multiple replicas intra-node

It’s somewhat related to [FSDP] HYBRID_SHARD Apply FULL_SHARD across multiple nodes instead of just intar-node · Issue #117470 · pytorch/pytorch · GitHub but in the opposite direction:

  • Replicas within a node and across node(s)
  • Sharding within a node but only to a limited number of devices

For example, if we had 2 nodes with 8 GPUs each, I’d like to have FSDP/HSDP with 4 GPUs per shard group and 4 replicas (2 per node). I’d assume that setting device_mesh to 4x4 would (maybe?) do this. So, my question is the following:

  • Is it guaranteed that sharding is initialized within a node first if I create a device mesh, i.e. is inter-node sharding avoided?

@vasqu

When you create a 2D DeviceMesh and pass it to HSDP, it always assumes the outer dimension is the replicate group and the inner dimension is the shard group. Under the hood, HSDP just uses the process groups (code pointer here: pytorch/torch/distributed/fsdp/_init_utils.py at main · pytorch/pytorch · GitHub). DeviceMesh simply makes it easy to initialize those process groups.
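As a minimal sketch of how the 2D mesh plugs into FSDP for HSDP (my own illustration, not from your setup): it assumes a torchrun launch with 8 processes per node on 2 nodes (16 ranks total), and the tiny Linear model is just a placeholder.

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # LOCAL_RANK is set by torchrun

# Outer dim = replicate (4 replicas), inner dim = shard (4-way sharding within each replica).
mesh_2d = init_device_mesh("cuda", (4, 4), mesh_dim_names=("replicate", "shard"))

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model
hsdp_model = FSDP(
    model,
    device_mesh=mesh_2d,                              # 2D mesh -> HSDP
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,  # shard within inner dim, replicate across outer dim
)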


@irisz

Sorry, just for clarification, as I’m not familiar with the device mesh initialization (/abstraction), I’d like to expand on my example. Assuming I have node 1 with ranks 0 to 7 (8 in total) and node 2 with ranks 8 to 15 (also 8 in total), I’d then expect a (4, 4) device mesh to give the following process groups:

  • [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15] (sharding pgs)
  • [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15] (replication pgs)

Is that correct? Thanks in advance!

Edit: In essence, I want to achieve this:

[graphic of the intended replicate/shard layout omitted]

Maybe the graphic can help here.

Yes. This is right. You would get this:

  • [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15] (sharding pgs)
  • [0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15] (replication pgs)

To get this, you would want to initialize a DeviceMesh like this:

from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (4, 4), mesh_dim_names=("replicate", "shard"))

Here is a bit more information: Getting Started with DeviceMesh — PyTorch Tutorials 2.4.0+cu121 documentation
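If it helps, here is a small verification sketch (my addition, not from the tutorial) to double-check which ranks end up in each group at runtime; it assumes the same 2-node x 8-GPU torchrun launch and uses DeviceMesh.get_group plus torch.distributed.get_process_group_ranks.

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

mesh_2d = init_device_mesh("cuda", (4, 4), mesh_dim_names=("replicate", "shard"))

# Each rank belongs to exactly one shard group and one replicate group.
shard_pg = mesh_2d.get_group(mesh_dim="shard")          # e.g. ranks [0, 1, 2, 3] on rank 0
replicate_pg = mesh_2d.get_group(mesh_dim="replicate")  # e.g. ranks [0, 4, 8, 12] on rank 0

print(
    f"rank {dist.get_rank()}: "
    f"shard group {dist.get_process_group_ranks(shard_pg)}, "
    f"replicate group {dist.get_process_group_ranks(replicate_pg)}"
)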


Thank you! I think I got it now.
