Hi,
I’m fine-tuning a multimodal LLM and I run into an error when saving a checkpoint. More specifically, the model itself saves normally, but saving the optimizer states fails with:
RuntimeError: Detected mismatch between collectives on ranks. Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=549, OpType=ALLGATHER, TensorShape=[0], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 2 is running collective: CollectiveFingerPrint(SequenceNumber=549, OpType=ALLGATHER, TensorShape=[183971584], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))).Collectives differ in the following aspects: Tensor Tensor shapes: 0vs 183971584
This is the traceback of the tensor-shape mismatch when saving the FSDP optimizer states:
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 2356, in _inner_training_loop
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 2807, in _maybe_log_save_evaluate
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 2890, in _save_checkpoint
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 3001, in _save_optimizer_and_scheduler
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 185, in save_fsdp_optimizer
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1828, in optim_state_dict
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1253, in _optim_state_dict_impl
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1396, in _optim_state_dict
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1657, in _gather_orig_param_state
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1593, in _all_gather_optim_state
The package versions are:
torch==2.1.2
numpy==1.26.4
transformers==4.43.1
Below are the details of how I traced the error through the functions listed in the traceback. It’s quite long, but it leads to what I believe is a possible reason for the error above, and to my question below.
Summary
I began digging into the codebase, starting with torch.distributed.fsdp, to find out the cause, as follows:
- First of all, the error stems from _all_gather_optim_state() in this file: pytorch/torch/distributed/fsdp/_optim_utils.py at v2.1.2 · pytorch/pytorch · GitHub
It happens that when execution enters this line:
work = dist.all_gather(
tensors, local_state, group=fsdp_state.process_group, async_op=True
)
tensors is a list in which some entries are zero-sized and others are not, which is exactly the mismatch raised in the error above. When I printed out object_state.tensors:
2025-02-22 17:02:42,065 - root - DEBUG - rank 0, object_state.tensors: {}, name: exp_avg, info: None
2025-02-22 17:02:42,065 - root - DEBUG - rank 1, object_state.tensors: {}, name: exp_avg, info: None
2025-02-22 17:02:42,066 - root - DEBUG - rank 2, object_state.tensors: {'exp_avg': _PosDimTensorInfo(shape=torch.Size([183971584]), dtype=torch.float32), 'exp_avg_sq': _PosDimTensorInfo(shape=torch.Size([183971584]), dtype=torch.float32)}, name: exp_avg, info: _PosDimTensorInfo(shape=torch.Size([183971584]), dtype=torch.float32)
2025-02-22 17:02:42,066 - root - DEBUG - rank 3, object_state.tensors: {'exp_avg': _PosDimTensorInfo(shape=torch.Size([210030848]), dtype=torch.float32), 'exp_avg_sq': _PosDimTensorInfo(shape=torch.Size([210030848]), dtype=torch.float32)}, name: exp_avg, info: _PosDimTensorInfo(shape=torch.Size([210030848]), dtype=torch.float32)
It can be seen that on ranks 0 and 1 the tensors are empty. These object_state.tensors are gathered into object_list from the processes in the process group via:
dist.all_gather_object(object_list, processed_state, group=fsdp_state.process_group)
It seems that processed_state for ranks 0 and 1 is empty (StateInfo({}, {}, {})), and this is caused by an empty optim_state, as the for loop at the beginning of _all_gather_optim_state() shows:
2025-02-22 17:02:42,060 - root - DEBUG - @tcm: In _all_gather_optim_state(): optim_state: {}
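To convince myself of what the error message means, here is a standalone sketch (illustrative only, not my training code) in which two CPU ranks call dist.all_gather with differently-sized inputs, which is the same situation as above: one rank contributes a 0-numel tensor while the other contributes a full shard. As far as I understand, the CollectiveFingerPrint check only runs when TORCH_DISTRIBUTED_DEBUG=DETAIL is set, in which case this sketch raises the same kind of “Detected mismatch between collectives on ranks” error:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 plays the role of my ranks 0/1 (empty optimizer state -> 0-numel input),
    # rank 1 plays the role of my ranks 2/3 (a real shard of exp_avg / exp_avg_sq).
    local_state = torch.empty(0) if rank == 0 else torch.randn(8)
    outputs = [torch.empty(0), torch.empty(8)]

    # The ranks disagree on the input shape of this collective.
    dist.all_gather(outputs, local_state)

    dist.destroy_process_group()


if __name__ == "__main__":
    # The fingerprint check that produced my error only runs in DETAIL debug mode.
    os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")
    mp.spawn(worker, args=(2,), nprocs=2)
```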
- After that, I looked into the next function up in the traceback, _gather_orig_param_state(): pytorch/torch/distributed/fsdp/_optim_utils.py at v2.1.2 · pytorch/pytorch · GitHub
The optim_state dictionary in this function is still empty:
2025-02-22 17:02:42,060 - root - DEBUG - @tcm: In _gather_orig_param_state(): optim_state: {}
- I continued with the next function up in the traceback, _optim_state_dict(): pytorch/torch/distributed/fsdp/_optim_utils.py at v2.1.2 · pytorch/pytorch · GitHub
In this function, I focused on the following part:
# Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
# across ranks
for optim_state_key in all_optim_state_keys:
param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
optim_state_key, None
)
if param_key is None:
assert use_orig_params, (
"If use_orig_params is False, we must be able to find the "
f"corresponding param id. {optim_state_key} {param_key}"
)
if not optim_state_key.is_fsdp_managed:
continue
if optim_state_key.is_fsdp_managed:
# If there are multiple unflat_param_names (not use_orig_params),
# they share the same FSDPParamInfo. So the first unflat_param_name
# is sufficient to fetch the FSDPParamInfo.
fqn = optim_state_key.unflat_param_names[0]
fsdp_param_info = fqn_to_fsdp_param_info[fqn]
if use_orig_params:
state = (
{} if param_key is None else optim_state_dict["state"][param_key]
)
unflat_state = [
_gather_orig_param_state(
fsdp_param_info,
fqn,
state,
shard_state,
)
]
The problem is that param_key is None, which leads to an empty state being passed to _gather_orig_param_state():
state = ({} if param_key is None else optim_state_dict["state"][param_key])
param_key is None because the optim_state_key_to_param_key dictionary is empty:
2025-02-22 17:02:42,046 - root - DEBUG - @tcm: In _optim_state_dict(): optim_state_key_to_param_key: {} # rank 0 or 1
2025-02-22 17:02:42,046 - root - DEBUG - @tcm: In _optim_state_dict(): optim_state_key_to_param_key: {_OptimStateKey(unflat_param_names=('lm_head.weight',), is_fsdp_managed=True): 2} # rank 2 or 3
2025-02-22 17:02:42,046 - root - DEBUG - @tcm: In _optim_state_dict(): optim_state_key_to_param_key: {_OptimStateKey(unflat_param_names=('model.mm_projector.0.weight',), is_fsdp_managed=True): 0, _OptimStateKey(unflat_param_names=('model.mm_projector.2.weight',), is_fsdp_managed=True): 1, _OptimStateKey(unflat_param_names=('model.mm_projector.0.bias',), is_fsdp_managed=True): 3, _OptimStateKey(unflat_param_names=('model.mm_projector.2.bias',), is_fsdp_managed=True): 4} # rank 2 or 3
To understand why optim_state_key_to_param_key is empty, I looked into _map_param_key_to_optim_keys(): pytorch/torch/distributed/fsdp/_optim_utils.py at v2.1.2 · pytorch/pytorch · GitHub
If we look at the for loop at the beginning of this function:
for param_key, param in param_key_to_param.items():
# Do not include parameters without state to avoid empty mappings
# just like in normal `torch.optim.Optimizer.state_dict()`
if param_key not in optim_state_dict["state"]:
continue
optim_state_dict["state"] is empty, so every iteration is skipped and optim_state_key_to_param_key is never updated:
2025-02-22 17:02:42,041 - root - DEBUG - @tcm: In _map_param_key_to_optim_keys(): optim_state_dict["state"]: {} # in empty ranks such as 0
2025-02-22 17:02:28,436 - root - DEBUG - @tcm: In _map_param_key_to_optim_keys(): optim_state_dict["state"]: {0: {'step': tensor(1.), 'exp_avg': tensor([-6.2440e-07, 6.8065e-07, -2.2726e-06, ..., 3.1220e-07, 1.6088e-06, -1.5047e-07], device='cuda:1'), 'exp_avg_sq': tensor([3.8987e-14, 4.6328e-14, 5.1646e-13, ..., 9.7467e-15, 2.5882e-13, 2.2642e-15], device='cuda:1')}...
So the problem is that optim_state_dict is empty when it is passed into _optim_state_dict(): the empty ranks still iterate over the same all_optim_state_keys as everyone else (to keep the all-gathers aligned, per the comment in the code above), but they contribute an empty state, which is presumably where the zero-sized tensor on ranks 0/1 comes from.
- If we go further up the traceback:
_optim_state_dict_impl(): pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py at v2.1.2 · pytorch/pytorch · GitHub
FSDP.optim_state_dict(): pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py at v2.1.2 · pytorch/pytorch · GitHub
we see that optim_state_dict is initially None:
2025-02-22 17:02:28,304 - root - DEBUG - @tcm: In FSDP.optim_state_dict(): optim_state_dict: None
and is then initialized via:
if optim_state_dict is None:
optim_state_dict = optim.state_dict()
- Now, when I go to the transformers trainer.py and look at _save_optimizer_and_scheduler(): transformers/src/transformers/trainer.py at v4.43.1 · huggingface/transformers · GitHub
we can see that when saving the FSDP optimizer, no optim_state_dict is passed:
save_fsdp_optimizer(self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir)
So optim_state_dict ends up being initialized by the plain torch.optim.Optimizer.state_dict(): pytorch/torch/optim/optimizer.py at main · pytorch/pytorch · GitHub
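Putting this together, my reading of what the Trainer call above effectively does is the following (a hedged paraphrase, not the exact accelerate/torch code; model and optimizer stand for the FSDP-wrapped model and the optimizer created in create_optimizer() further down):

```python
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Since no optim_state_dict is passed to save_fsdp_optimizer, FSDP falls back to
# optimizer.state_dict() on every rank. This is what each rank feeds into the
# gathering logic, and on my ranks 0/1 its "state" entry is empty:
print(f"rank {dist.get_rank()}: local optimizer state keys = "
      f"{sorted(optimizer.state_dict()['state'].keys())}")

# This is the call the traceback goes through (accelerate's save_fsdp_optimizer ->
# FSDP.optim_state_dict -> _optim_state_dict_impl -> _optim_state_dict ->
# _gather_orig_param_state -> _all_gather_optim_state, where the mismatch fires):
optim_state = FSDP.optim_state_dict(model, optimizer)
```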
For context, this is the architecture of the model I’m trying to fine-tune:
FullyShardedDataParallel(
(_fsdp_wrapped_module): CambrianLlamaForCausalLM(
(model): CambrianLlamaModel(
(embed_tokens): Embedding(128256, 3072)
(layers): ModuleList(
(0-27): 28 x FullyShardedDataParallel(
(_fsdp_wrapped_module): LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=3072, out_features=3072, bias=False)
(k_proj): Linear(in_features=3072, out_features=1024, bias=False)
(v_proj): Linear(in_features=3072, out_features=1024, bias=False)
(o_proj): Linear(in_features=3072, out_features=3072, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
(up_proj): Linear(in_features=3072, out_features=8192, bias=False)
(down_proj): Linear(in_features=8192, out_features=3072, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
)
(norm): LlamaRMSNorm()
(rotary_emb): LlamaRotaryEmbedding()
(mm_projector): Sequential(
(0): Linear(in_features=1024, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=3072, bias=True)
)
(mm_projector_aux_0): Sequential(
(0): Linear(in_features=1152, out_features=1024, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=1024, out_features=1024, bias=True)
(3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(mm_projector_aux_1): Sequential(
(0): Linear(in_features=1536, out_features=1024, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=1024, out_features=1024, bias=True)
(3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(vision_sampler_0): VisionTokenSampler(
(layers): ModuleList(
(0-2): 3 x VisionCrossAttentionLayer(
(proj_context): Linear(in_features=1024, out_features=1024, bias=False)
(proj_in): Linear(in_features=2048, out_features=1024, bias=False)
(proj_out): MLP(
(linear_1): Linear(in_features=1024, out_features=1024, bias=False)
(act): GELU(approximate='none')
(linear_2): Linear(in_features=1024, out_features=1024, bias=False)
)
(norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(cross_attn): MultiKVCrossAttention(
(q_proj): Sequential(
(0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=1024, out_features=1024, bias=False)
)
(k_proj_0): Sequential(
(0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=1024, out_features=1024, bias=False)
)
(v_proj_0): Sequential(
(0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=1024, out_features=1024, bias=False)
)
(k_proj_1): Sequential(
(0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=1024, out_features=1024, bias=False)
)
(v_proj_1): Sequential(
(0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(1): Linear(in_features=1024, out_features=1024, bias=False)
)
(o_proj): Linear(in_features=1024, out_features=1024, bias=False)
)
)
)
)
(lm_head): Linear(in_features=3072, out_features=128256, bias=False)
)
)
)
There are two levels of FSDP wrapping: the root instance around the entire model, and one instance per LlamaDecoderLayer. In my fine-tuning script, this is how I configure the FSDP options passed to Trainer:
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
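As far as I understand the Trainer/accelerate plumbing (my approximation, I have not verified it line by line), these two flags roughly correspond to the raw FSDP setup below. That would explain why only the decoder layers get their own FSDP unit, while embed_tokens, lm_head, the mm_projector* modules and vision_sampler_0 are all flattened into the root instance’s flat parameter:

```python
# Rough equivalent of "--fsdp 'full_shard auto_wrap'" plus
# "--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer'" (my approximation,
# not the exact accelerate code). `model` stands for the CambrianLlamaForCausalLM
# instance; the LlamaDecoderLayer class may come from the model codebase rather
# than from transformers itself.
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    auto_wrap_policy=auto_wrap_policy,
    # Judging by the _gather_orig_param_state frames in the traceback,
    # use_orig_params is enabled in my run.
    use_orig_params=True,
)
```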
Since the error is related to saving the FSDP optimizer states, here is the create_optimizer() method of my LLaVATrainer class, which is a subclass of the Hugging Face Trainer:
class LLaVATrainer(Trainer):
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
# pyre-fixme[16]: `Trainer` has no attribute `model`.
opt_model = self.model
# if self.args.unfreeze_mm_vision_tower:
# opt_model.get_model().vision_tower_aux_list = nn.ModuleList(opt_model.get_vision_tower_aux_list())
# self.param_to_name = map_params_to_module_names([opt_model])
# pyre-fixme[16]: `Trainer` has no attribute `optimizer`.
if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
# pyre-fixme[16]: `Trainer` has no attribute `mm_projector_lr`.
assert not (self.args.mm_projector_lr and self.args.mm_vision_sampler_lr)
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args
)
self.optimizer = optimizer_cls(
optimizer_grouped_parameters, **optimizer_kwargs
)
return self.optimizer
The create_optimizer() method above builds the optimizer parameter groups from the trainable parameters. When I print the named parameters it iterates over, the output is as follows:
2025-02-22 17:01:20,658 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing: tensor([], device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
2025-02-22 17:01:20,695 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing:
tensor([], device='cuda:1', dtype=torch.bfloat16, requires_grad=True)
...
2025-02-22 17:01:20,690 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing: tensor([ 0.0103, 0.0090, 0.0134, ..., 0.0049, -0.0025, -0.0052], device='cuda:2', dtype=torch.bfloat16, requires_grad=True)
2025-02-22 17:01:20,691 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing: tensor([-0.0099, -0.0302, -0.0054, ..., -0.0038, -0.0027, -0.0015],
device='cuda:3', dtype=torch.bfloat16, requires_grad=True)
As can be seen, on ranks 0 and 1 the lm_head weight inside the root FSDP unit is an empty tensor, while it contains data on ranks 2 and 3. So I think the error might stem from this FSDP sharding, where the same lm_head parameter is sharded onto ranks 2 and 3 but is empty on ranks 0 and 1. Therefore, my question is:
Why does FSDP in Trainer shard a layer such that it is empty on certain ranks, possibly leading to the error above?
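To make the behaviour I am asking about concrete, here is a standalone sketch (mine, not the training code) showing that with use_orig_params=True a rank that owns no elements of a given original parameter sees a zero-numel view of it, exactly like lm_head.weight on ranks 0 and 1 above. It assumes 2 GPUs and is launched with torchrun --nproc_per_node=2 sketch.py:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Two 4x4 linears give a 32-element flat parameter, split 16/16 across the 2 ranks.
# With this toy sizing the split lands exactly on the parameter boundary, so each
# rank owns one of the two weights and gets a 0-numel view of the other.
model = nn.Sequential(
    nn.Linear(4, 4, bias=False),
    nn.Linear(4, 4, bias=False),
).cuda()
fsdp_model = FSDP(model, use_orig_params=True)

for name, p in fsdp_model.named_parameters():
    # e.g. rank 0: "...0.weight" -> 16, "...1.weight" -> 0 (mirrors lm_head on ranks 0/1)
    print(f"rank {rank}: {name} -> local numel {p.numel()}")

dist.destroy_process_group()
```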
- I further investigated optim.Optimizer.state_dict() and noticed that on ranks 1, 2, and 3 the self.state dict is populated:
2025-02-24 04:16:07,667 - root - DEBUG - @tcm: In optim.Optimizer.state_dict(): self.state=defaultdict(<class 'dict'>, {Parameter containing: tensor([-0.0099, -0.0302, -0.0054, ..., -0.0038, -0.0027, -0.0015],
device='cuda:3', requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([-2.1631e-11, -6.3862e-11, -3.7596e-11, ..., -1.6094e-12, 7.6394e-12, -6.6093e-12], device='cuda:3'), 'exp_avg_sq': tensor([4.6788e-23, 4.0783e-22, 1.4135e-22, ..., 2.5902e-25, 5.8360e-24, 4.3683e-24], device='cuda:3')}})
2025-02-24 04:16:07,667 - root - DEBUG - @tcm: In optim.Optimizer.state_dict(): self.state=defaultdict(<class 'dict'>, {Parameter containing: tensor([ 0.0103, 0.0090, 0.0134, ..., 0.0049, -0.0025, -0.0052],
device='cuda:2', requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([-1.9513e-08, 5.0979e-09, 1.5206e-08, ..., -2.0051e-10, 9.6136e-11, -1.6412e-10], device='cuda:2'), 'exp_avg_sq': tensor([3.8075e-17, 2.5989e-18, 2.3122e-17, ..., 4.0205e-21, 9.2421e-22, 2.6935e-21], device='cuda:2')}})
but on rank 0, the self.state dict is empty:
2025-02-24 04:16:18,691 - root - DEBUG - @tcm: In optim.Optimizer.state_dict(): self.state=defaultdict(<class 'dict'>, {})
So I think this is the root cause of the error I have been asking about all along. I don’t understand why, on rank 0, the optimizer’s self.state is empty, unlike on the other ranks.
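My current guess at the mechanism (hedged; this is plain torch.optim behaviour, not anything FSDP-specific): Adam/AdamW create per-parameter state lazily inside step(), and only for parameters that actually have a gradient. So if every trainable parameter a rank holds is an empty shard that never receives a gradient, that rank’s self.state stays empty, which would line up with the empty lm_head shards on ranks 0 and 1 shown above. A minimal sketch:

```python
import torch

a = torch.nn.Parameter(torch.randn(4))
b = torch.nn.Parameter(torch.randn(4))  # stands in for a parameter that never gets a gradient
opt = torch.optim.AdamW([a, b], lr=1e-3)

a.sum().backward()  # only `a` receives a gradient
opt.step()

# State exists for `a` only; `b` has no entry at all.
print(opt.state_dict()["state"].keys())  # dict_keys([0])
```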
I’ve been trying my best to track down the root cause and fix this error, but the codebase is large and complex, so I am seeking help from the community.
Thanks in advance.