Strange behavior of HSDP

I followed the DeviceMesh tutorial and tested the HSDP code below. When I print the summed weights, I expect to see the same value four times, once per rank. Instead, I see two sets of repeated values. Does this mean HSDP did not sync the weights along the replicate dimension of the mesh?

import os

import torch
import torch.nn as nn

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


# HSDP: mesh shape (2, 2) -> dim 0 is "replicate", dim 1 is "shard"
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("replicate", "shard"))
model = FSDP(
    ToyModel().cuda(),
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    use_orig_params=True,
)

with FSDP.summon_full_params(model, with_grads=True):
    print(f"local_rank={local_rank}, model.net1.weight.data.sum()={model.net1.weight.data.sum()}")
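
For reference, I launch this on 4 GPUs with torchrun --nproc_per_node=4 hsdp_test.py; torchrun sets the LOCAL_RANK environment variable that the script reads (the filename is just for illustration).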

From the example, the mesh topology is [[0, 1], [2, 3]] (the snippet after this list reads the groups back from the mesh to confirm):

  • [0, 1] and [2, 3] are the shard groups (the "shard" mesh dimension)
  • [0, 2] and [1, 3] are the replicate groups (the "replicate" mesh dimension)
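
To double-check this mapping at runtime, the groups can be read back from the mesh itself. This is a sketch: get_group(mesh_dim) is the method name on recent PyTorch releases (older versions expose get_dim_groups() instead), and get_process_group_ranks() comes from torch.distributed.

import torch.distributed as dist

# Show which ranks share this rank's shard and replicate groups.
shard_pg = mesh_2d.get_group("shard")
replicate_pg = mesh_2d.get_group("replicate")
print(
    f"rank={dist.get_rank()}, "
    f"shard_group={dist.get_process_group_ranks(shard_pg)}, "
    f"replicate_group={dist.get_process_group_ranks(replicate_pg)}"
)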

Since summon_full_params gathers the full parameters within each shard group, ranks in the same shard group should always print the same value; if the replicate groups were in sync as well, all four ranks would match. Below is the printout from a sample run:

local_rank=0, model.net1.weight.data.sum()=tensor(-0.4711, device='cuda:0')
local_rank=1, model.net1.weight.data.sum()=tensor(-0.4711, device='cuda:1')
local_rank=2, model.net1.weight.data.sum()=tensor(0.2237, device='cuda:2')
local_rank=3, model.net1.weight.data.sum()=tensor(0.2237, device='cuda:3')
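
To confirm that the two sets really split along the replicate dimension rather than somewhere else, here is a diagnostic sketch (assuming the variables from the script above and an NCCL backend, which supports ReduceOp.AVG): it averages each rank's local sum across its replicate group, which should equal the local value if the replicas are in sync.

import torch.distributed as dist

with FSDP.summon_full_params(model):
    local_sum = model.net1.weight.data.sum()
    group_avg = local_sum.clone()
    # Average the summed weight across this rank's replicate group.
    dist.all_reduce(group_avg, op=dist.ReduceOp.AVG, group=mesh_2d.get_group("replicate"))
    # If HSDP synced the replicas, every rank reports True here.
    print(f"rank={dist.get_rank()}, replicas_in_sync={torch.allclose(local_sum, group_avg)}")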