Hi
I’ve been training a decoder-only transformer model with FSDP. Up to this point, all model parameters, optimizer states, etc. have been sharded across all ranks, since the FSDP units wrapping my modules shard over a single global process group.
I then moved on to an MoE implementation where, in addition to the parameters sharded across all ranks, experts are sharded only across subsets of ranks. This works fine: experts are wrapped in FSDP units that belong to specific sub-process groups, alongside the FSDP units on the global process group.
I’m still using my original checkpoint saving/loading code in the MoE implementation. I save and load checkpoints following the examples in the FSDP docs: a context manager with the SHARDED_STATE_DICT state dict type, where each rank saves its own shard (which should contain the common parameters such as attention weights plus that rank’s expert parameter shards from the sub-groups).
Now, when I try to have each rank load its respective checkpoint while resuming a run, I get an error saying the local rank used to save doesn’t match the rank while loading. When setting up the processes launched by mp.spawn(…), before anything else I call torch.cuda.set_device(rank), which should always map process 0 to device 0 and so on.
I also make sure to load the checkpoints only after the model has been wrapped with FSDP units in the same way.
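For reference, each spawned process is set up roughly like this before any wrapping or loading happens (simplified sketch; run_training and the MASTER_ADDR/MASTER_PORT defaults stand in for my actual setup):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_training(rank, world_size):
    # Pin this process to its GPU before anything else, then join the
    # global process group (the MASTER_ADDR/MASTER_PORT values are placeholders)
    torch.cuda.set_device(rank)
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # ... build the model, wrap it with FSDP, then load this rank's checkpoint ...


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(run_training, args=(world_size,), nprocs=world_size)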
I’d appreciate any reference material on how saving and loading checkpoints should be done when different sharding types are mixed (shards across all ranks together with shards specific to sub-process-group ranks).
Thank you!
Reference code snippets
- model and optimizer state_dict extraction while saving
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)

# fsdp sharded state dict type for checkpointing
with FSDP.state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(offload_to_cpu=True),
    ShardedOptimStateDictConfig(offload_to_cpu=True),
):
    model_state_dict = model.state_dict()
    original_optim_state_dict = optimizer.state_dict()
    optim_state_dict = FSDP.optim_state_dict(
        model=model,
        optim=optimizer,
        optim_state_dict=original_optim_state_dict,
    )
    scheduler_state_dict = scheduler.state_dict()
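Each rank then writes its own file with a plain torch.save; the dictionary keys below match what the loading snippet reads back, while checkpoint_dir and the file naming are placeholders:

# Each rank saves the sharded state it owns to its own rank-indexed file
# (checkpoint_dir and the naming scheme are placeholders)
checkpoint = {
    "model_state_dict": model_state_dict,
    "optimizer_state_dict": optim_state_dict,
    "scheduler_state_dict": scheduler_state_dict,
}
torch.save(checkpoint, os.path.join(checkpoint_dir, f"checkpoint_rank{rank}.pt"))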
- loading model and optimizer state from saved checkpoints
# Load the checkpoint for this rank
checkpoint = torch.load(checkpoint_path, map_location=f"cuda:{rank}")

# Load state dicts from checkpoint
model_state_dict = checkpoint["model_state_dict"]
optimizer_state_dict = checkpoint["optimizer_state_dict"]
scheduler_state_dict = checkpoint["scheduler_state_dict"]

# Load model state dict using FSDP sharded state dict utility
with FSDP.state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(offload_to_cpu=True),
    ShardedOptimStateDictConfig(offload_to_cpu=True),
):
    # Load model state dict
    model.load_state_dict(model_state_dict)

    # Load optimizer state dict
    optim_state_dict = FSDP.optim_state_dict_to_load(
        model, optimizer, optimizer_state_dict
    )
    optimizer.load_state_dict(optim_state_dict)

# Load scheduler state
scheduler.load_state_dict(scheduler_state_dict)
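checkpoint_path here is built from the current rank so each process reads back the file it wrote (same placeholder naming as the save sketch above):

# Placeholder path scheme matching what each rank saved
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_rank{rank}.pt")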
- wrapping experts in sub-process-group-specific FSDP units
# Wrap each expert in an FSDP unit on its own sub-process group
new_experts_list = nn.ModuleList()
for idx, expert in enumerate(module.experts):
    pg, ranks = expert_groups[idx]
    if rank in ranks:
        new_expert = FSDP(
            expert,
            process_group=pg,
            cpu_offload=CPUOffload(offload_params=False),
            device_id=rank,
            mixed_precision=MixedPrecision(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32,
            ),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
        )
        new_experts_list.append(new_expert)
setattr(module, "experts", new_experts_list)
- wrapping all other parameters in FSDP units on the global process group
# Wrap each transformer block in an FSDP unit on the default (global)
# process group; recurse into other modules
if isinstance(module, tuple(transformer_layer_cls)):
    setattr(
        model,
        name,
        FSDP(
            module,
            cpu_offload=CPUOffload(offload_params=False),
            device_id=rank,
            mixed_precision=MixedPrecision(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32,
            ),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
        ),
    )
else:
    wrap_non_expert_modules(module, rank, transformer_layer_cls)
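Tying these together, the wrapping order I use is experts first, then everything else (wrap_experts is a placeholder name for the expert-wrapping snippet above):

# Illustrative call order: expert sub-group FSDP units first, then the
# global-group FSDP units, and only afterwards save/load checkpoints
wrap_experts(model, rank, expert_groups)
wrap_non_expert_modules(model, rank, transformer_layer_cls)

As I understand it, wrapping the experts first means the later global-group FSDP units skip parameters already owned by the nested expert units, so the expert shards stay on their sub-groups.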