Strange behavior of HSDP

I followed the tutorial on DeviceMesh to test the following HSDP code. When I printed the weights, I expected to see the same value 4 times. Instead, I see two sets of repeated values. Does this mean HSDP did not sync the weights along the replicate_group axis?

import os

import torch
import torch.nn as nn

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


# HSDP: MeshShape(2, 2)
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("replicate", "shard"))
model = FSDP(
    ToyModel().cuda(),
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    use_orig_params=True,
)

with FSDP.summon_full_params(model, with_grads=True):
    print(f"{local_rank=}, {model.net1.weight.data.sum()=}")

From the example, we have the mesh topology [[0, 1], [2, 3]]:

  • [0, 1] and [2, 3] are the replicate groups
  • [0, 2] and [1, 3] are the shard groups

So we should see two sets of the same value. Below is the printout from a sample run:

local_rank=0, model.net1.weight.data.sum()=tensor(-0.4711, device='cuda:0')
local_rank=1, model.net1.weight.data.sum()=tensor(-0.4711, device='cuda:1')
local_rank=2, model.net1.weight.data.sum()=tensor(0.2237, device='cuda:2')
local_rank=3, model.net1.weight.data.sum()=tensor(0.2237, device='cuda:3')
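
One way to double-check which ranks end up in which group would be to slice the device mesh by its dimension names and print the ranks of each resulting process group. This is only a minimal sketch, assuming the mesh_2d from the snippet above and a PyTorch version where DeviceMesh supports slicing by name:

import torch.distributed as dist

# Each named slice of the 2D mesh carries the process group that this
# rank belongs to along that dimension.
replicate_group = mesh_2d["replicate"].get_group()
shard_group = mesh_2d["shard"].get_group()

print(
    f"rank {dist.get_rank()}: "
    f"replicate ranks = {dist.get_process_group_ranks(replicate_group)}, "
    f"shard ranks = {dist.get_process_group_ranks(shard_group)}"
)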

My team has been experimenting with something similar and facing some inconsistencies. Our toy setup has 2 nodes with 4 GPUs each and 2 replicas total, which makes for 4 shards per replica (i.e. 1 replica per node). In essence, we used the same code as @yenchenlin, except with init_device_mesh("cuda", (2, 4)) instead of (2, 2).

We’re noticing the same thing as @yenchenlin: when we initialize the models differently on each rank before wrapping with FSDP, then wrap with FSDP and print the net1 sum, we see 2 unique values, each printed 4 times. This makes us wonder whether HSDP is really syncing weights or not. However, if we fix our seed to be consistent across all ranks before initializing our toy model (sketched below), then the printout of model.net1.weight.data.sum() within the FSDP.summon_full_params() context is the same across all ranks.
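
For concreteness, here is a minimal sketch of that seed-fixed variant, assuming the same ToyModel class as above and our 2-node, 4-GPUs-per-node setup; the seed value and print format are arbitrary choices for illustration:

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Fixing the seed before model construction means every rank builds
# identical initial weights, so both replicas start out in sync.
torch.manual_seed(0)

# 2 replicas x 4 shards = 8 ranks across the 2 nodes
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))
model = FSDP(
    ToyModel().cuda(),
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    use_orig_params=True,
)

with FSDP.summon_full_params(model):
    print(f"rank {dist.get_rank()}: net1 sum = {model.net1.weight.data.sum()}")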

It was also our understanding that FSDP.summon_full_params() enables each rank to reach the respective parameters of its entire replica, and thus the ranks that report the same sum would actually be part of the same replica. But this doesn’t seem to align with your printouts @irisz, unless my understanding of the “replicate” and “shard” groups is incorrect (which was that a shard group makes up a complete replica of the model).
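
As a sanity check along these lines, one could gather the summoned net1 sums across the replicate group and compare them directly. Again just a sketch, reusing the mesh_2d and model from above; the all_gather pattern is our own illustration, not something from the tutorial:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

replicate_group = mesh_2d["replicate"].get_group()
with FSDP.summon_full_params(model):
    # Gather each replica's full-parameter sum over the "replicate" group
    # and check whether the replicas actually agree.
    local_sum = model.net1.weight.data.sum().reshape(1)
    gathered = [torch.empty_like(local_sum) for _ in range(dist.get_world_size(replicate_group))]
    dist.all_gather(gathered, local_sum, group=replicate_group)
    if dist.get_rank() == 0:
        print(f"net1 sums across replicas: {[t.item() for t in gathered]}")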

Can you provide us with some info (and sources, if available) on whether or not PyTorch HSDP syncs weights across replicas?