I followed the DeviceMesh tutorial to test the HSDP code below. Since `summon_full_params` should gather the full parameters on every rank, I expected the printed sum of `net1.weight` to be identical across all four ranks. Instead, I see two distinct values, each repeated twice. Does this mean HSDP did not sync the weights along the `replicate` mesh dimension?
```python
# Run with: torchrun --nproc_per_node=4 hsdp_example.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


# HSDP: 2D mesh of shape (2, 2) -- 2 replica groups x 2 shard groups
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("replicate", "shard"))
model = FSDP(
    ToyModel().cuda(),
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    use_orig_params=True,
)

# Gather the full (unsharded) parameters on every rank and print a checksum
with FSDP.summon_full_params(model, with_grads=True):
    print(model.net1.weight.data.sum())
```
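
For completeness, here is a small diagnostic that can be appended to the script above to compare the checksums across ranks explicitly (a minimal sketch; it assumes a single-node launch with `torchrun --nproc_per_node=4`, and the names `local_sum` / `all_sums` are just illustrative):

```python
import torch.distributed as dist

# Gather the full parameters again and compute a per-rank checksum.
with FSDP.summon_full_params(model):
    local_sum = model.net1.weight.detach().sum().reshape(1)

# Collect every rank's checksum so rank 0 can print the full picture.
world_size = dist.get_world_size()
all_sums = [torch.zeros_like(local_sum) for _ in range(world_size)]
dist.all_gather(all_sums, local_sum)

if dist.get_rank() == 0:
    for rank, s in enumerate(all_sums):
        print(f"rank {rank}: {s.item():.6f}")
```

If the weights were synced along the `replicate` dimension, all four printed values should match, since every rank holds the full parameters inside `summon_full_params`.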