Tensor shape mismatch error when doing an allgather in distributed training with FSDP

Hi,

I’m fine-tuning a multimodal LLM and I encounter the following error when attempting to save a checkpoint. More specifically, the model itself saves normally, but when the optimizer states are saved, the following error occurs:

RuntimeError: Detected mismatch between collectives on ranks. Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=549, OpType=ALLGATHER, TensorShape=[0], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 2 is running collective: CollectiveFingerPrint(SequenceNumber=549, OpType=ALLGATHER, TensorShape=[183971584], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))).Collectives differ in the following aspects:   Tensor Tensor shapes: 0vs 183971584
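As I read it, rank 0 enters this all_gather with an empty (shape [0]) float tensor while rank 2 enters the same collective with a tensor of 183,971,584 elements. A minimal two-rank sketch of that failure mode (hypothetical code on gloo/CPU, not my training script) looks like this:

# Hypothetical repro of the failure class above (not my training code):
# two ranks call the same all_gather with differently sized local tensors.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )
    # rank 0 contributes an empty tensor, rank 1 a 4-element tensor
    local = torch.empty(0) if rank == 0 else torch.randn(4)
    gathered = [torch.empty(4) for _ in range(world_size)]
    dist.all_gather(gathered, local)  # expected to fail: ranks disagree on the tensor size
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)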

This is the full traceback of the tensor shape mismatch when saving the FSDP optimizer states:

File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 2356, in _inner_training_loop
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 2807, in _maybe_log_save_evaluate
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 2890, in _save_checkpoint
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 3001, in _save_optimizer_and_scheduler
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 185, in save_fsdp_optimizer
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1828, in optim_state_dict
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1253, in _optim_state_dict_impl
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1396, in _optim_state_dict
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1657, in _gather_orig_param_state
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1593, in _all_gather_optim_state

The package versions are:

torch==2.1.2
numpy==1.26.4
transformers==4.43.1

Below are the full details of how I traced the error through the functions listed in the traceback. It’s quite long, but it leads to what I believe is a possible cause of the error above, and to my question below:

Summary

I began digging into the codebase, starting with torch.distributed.fsdp, to find the cause, as follows:

work = dist.all_gather(
    tensors, local_state, group=fsdp_state.process_group, async_op=True
)

tensors is a list where certain elements have shape [0] and others are non-empty, matching the shapes reported in the error above. When I printed out object_state.tensors:

2025-02-22 17:02:42,065 - root - DEBUG - rank 0, object_state.tensors: {}, name: exp_avg, info: None
2025-02-22 17:02:42,065 - root - DEBUG - rank 1, object_state.tensors: {}, name: exp_avg, info: None
2025-02-22 17:02:42,066 - root - DEBUG - rank 2, object_state.tensors: {'exp_avg': _PosDimTensorInfo(shape=torch.Size([183971584]), dtype=torch.float32), 'exp_avg_sq': _PosDimTensorInfo(shape=torch.Size([183971584]), dtype=torch.float32)}, name: exp_avg, info: _PosDimTensorInfo(shape=torch.Size([183971584]), dtype=torch.float32)                                                                                            
2025-02-22 17:02:42,066 - root - DEBUG - rank 3, object_state.tensors: {'exp_avg': _PosDimTensorInfo(shape=torch.Size([210030848]), dtype=torch.float32), 'exp_avg_sq': _PosDimTensorInfo(shape=torch.Size([210030848]), dtype=torch.float32)}, name: exp_avg, info: _PosDimTensorInfo(shape=torch.Size([210030848]), dtype=torch.float32)

It can be seen that on ranks 0 and 1 the tensors are empty. These object_state.tensors are gathered into object_list from the processes in the process group via:

dist.all_gather_object(object_list, processed_state, group=fsdp_state.process_group)

It seems that processed_state for ranks 0 and 1 is empty (StateInfo({}, {}, {})), and this is caused by an empty optim_state, as can be seen from the for loop at the beginning of _all_gather_optim_state().

2025-02-22 17:02:42,060 - root - DEBUG - @tcm: In _all_gather_optim_state(): optim_state: {}
2025-02-22 17:02:42,060 - root - DEBUG - @tcm: In _gather_orig_param_state(): optim_state: {}
The empty optim_state is passed down from the caller, _optim_state_dict(), which builds state and invokes _gather_orig_param_state():

    # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
    # across ranks
    for optim_state_key in all_optim_state_keys:
        param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
            optim_state_key, None
        )

        if param_key is None:
            assert use_orig_params, (
                "If use_orig_params is False, we must be able to find the "
                f"corresponding param id. {optim_state_key} {param_key}"
            )
            if not optim_state_key.is_fsdp_managed:
                continue

        if optim_state_key.is_fsdp_managed:
            # If there are multiple unflat_param_names (not use_orig_params),
            # they share the same FSDPParamInfo. So the first unflat_param_name
            # is sufficient to fetch the FSDPParamInfo.
            fqn = optim_state_key.unflat_param_names[0]
            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
            if use_orig_params:
                state = (
                    {} if param_key is None else optim_state_dict["state"][param_key]
                )
                unflat_state = [
                    _gather_orig_param_state(
                        fsdp_param_info,
                        fqn,
                        state,
                        shard_state,
                    )
                ]

The problem is that param_key is None, which leads to an empty state being passed to _gather_orig_param_state():

state = ({} if param_key is None else optim_state_dict["state"][param_key])

param_key is None because the optim_state_key_to_param_key dictionary is empty:

2025-02-22 17:02:42,046 - root - DEBUG - @tcm: In _optim_state_dict(): optim_state_key_to_param_key: {} # rank 0 or 1
2025-02-22 17:02:42,046 - root - DEBUG - @tcm: In _optim_state_dict(): optim_state_key_to_param_key: {_OptimStateKey(unflat_param_names=('lm_head.weight',), is_fsdp_managed=True): 2} # rank 2 or 3
2025-02-22 17:02:42,046 - root - DEBUG - @tcm: In _optim_state_dict(): optim_state_key_to_param_key: {_OptimStateKey(unflat_param_names=('model.mm_projector.0.weight',), is_fsdp_managed=True): 0, _OptimStateKey(unflat_param_names=('model.mm_projector.2.weight',), is_fsdp_managed=True): 1, _OptimStateKey(unflat_param_names=('model.mm_projector.0.bias',), is_fsdp_managed=True): 3, _OptimStateKey(unflat_param_names=('model.mm_projector.2.bias',), is_fsdp_managed=True): 4} # rank 2 or 3

To understand why the dict optim_state_key_to_param_key is empty, I looked into _map_param_key_to_optim_keys() in the same file: pytorch/torch/distributed/fsdp/_optim_utils.py at v2.1.2 · pytorch/pytorch · GitHub
If we look at the for loop at the beginning of that function:

for param_key, param in param_key_to_param.items():
    # Do not include parameters without state to avoid empty mappings
    # just like in normal `torch.optim.Optimizer.state_dict()`
    if param_key not in optim_state_dict["state"]:
        continue

optim_state_dict["state"] is empty so the iteration is skipped, causing optim_state_key_to_param_key to not be updated.

2025-02-22 17:02:42,041 - root - DEBUG - @tcm: In _map_param_key_to_optim_keys(): optim_state_dict["state"]: {} # in empty ranks such as 0
2025-02-22 17:02:28,436 - root - DEBUG - @tcm: In _map_param_key_to_optim_keys(): optim_state_dict["state"]: {0: {'step': tensor(1.), 'exp_avg': tensor([-6.2440e-07,  6.8065e-07, -2.2726e-06,  ...,  3.1220e-07,          1.6088e-06, -1.5047e-07], device='cuda:1'), 'exp_avg_sq': tensor([3.8987e-14, 4.6328e-14, 5.1646e-13,  ..., 9.7467e-15, 2.5882e-13, 2.2642e-15],  device='cuda:1')}...

So the problem is that optim_state_dict is already empty when it is passed into _optim_state_dict().

2025-02-22 17:02:28,304 - root - DEBUG - @tcm: In FSDP.optim_state_dict(): optim_state_dict: None

and it is initialized via:

# inside FSDP.optim_state_dict() (fully_sharded_data_parallel.py, see the traceback above)
if optim_state_dict is None:
    optim_state_dict = optim.state_dict()

# called from transformers Trainer._save_optimizer_and_scheduler() via accelerate
save_fsdp_optimizer(self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir)

So the function used to initialize optim_state_dict is torch.optim.Optimizer.state_dict(): pytorch/torch/optim/optimizer.py at main · pytorch/pytorch · GitHub
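For reference, this is my understanding of what a plain (non-FSDP) Optimizer.state_dict() returns; its "state" dict, keyed by parameter index, is exactly what the loop in _map_param_key_to_optim_keys() above iterates over:

# Plain PyTorch sketch (no FSDP) of the structure that Optimizer.state_dict() returns.
import torch

lin = torch.nn.Linear(2, 2)
opt = torch.optim.AdamW(lin.parameters(), lr=1e-3)

lin(torch.randn(1, 2)).sum().backward()
opt.step()  # Adam state (step/exp_avg/exp_avg_sq) is created here

sd = opt.state_dict()
print(sd.keys())              # dict_keys(['state', 'param_groups'])
print(sd["state"].keys())     # parameter indices that have state, e.g. 0 and 1
print(sd["state"][0].keys())  # dict_keys(['step', 'exp_avg', 'exp_avg_sq'])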

For context, this is the architecture of the model that I’m trying to fine-tune:

FullyShardedDataParallel(
  (_fsdp_wrapped_module): CambrianLlamaForCausalLM(
    (model): CambrianLlamaModel(
      (embed_tokens): Embedding(128256, 3072)
      (layers): ModuleList(
        (0-27): 28 x FullyShardedDataParallel(
          (_fsdp_wrapped_module): LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
              (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
              (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
              (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
              (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
      )
      (norm): LlamaRMSNorm()
      (rotary_emb): LlamaRotaryEmbedding()
      (mm_projector): Sequential(
        (0): Linear(in_features=1024, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=3072, bias=True)
      )
      (mm_projector_aux_0): Sequential(
        (0): Linear(in_features=1152, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=1024, bias=True)
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (mm_projector_aux_1): Sequential(
        (0): Linear(in_features=1536, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=1024, bias=True)
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (vision_sampler_0): VisionTokenSampler(
        (layers): ModuleList(
          (0-2): 3 x VisionCrossAttentionLayer(
            (proj_context): Linear(in_features=1024, out_features=1024, bias=False)
            (proj_in): Linear(in_features=2048, out_features=1024, bias=False)
            (proj_out): MLP(
              (linear_1): Linear(in_features=1024, out_features=1024, bias=False)
              (act): GELU(approximate='none')
              (linear_2): Linear(in_features=1024, out_features=1024, bias=False)
            )
            (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (cross_attn): MultiKVCrossAttention(
              (q_proj): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (k_proj_0): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (v_proj_0): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (k_proj_1): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (v_proj_1): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
            )
          )
        )
      )
      (lm_head): Linear(in_features=3072, out_features=128256, bias=False)
    )
  )
)

There are two levels of FSDP wrapping: the root instance around the entire model and one instance around each of the 28 LlamaDecoderLayer modules. In my fine-tuning script, this is how I configure the FSDP options used by the Trainer:

--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
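As far as I understand (this is a sketch based on the transformers 4.43.x argument names, not my actual launch script), these flags map roughly onto the following TrainingArguments:

# Rough TrainingArguments equivalent of the CLI flags above (my assumption).
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints",
    fsdp="full_shard auto_wrap",
    fsdp_config={"transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"]},
)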

Since the error is related to saving FSDP optimizer states, I would also like to share the create_optimizer() method of my LLaVATrainer class, which is a subclass of the Hugging Face Trainer:

class LLaVATrainer(Trainer):
    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        # pyre-fixme[16]: `Trainer` has no attribute `model`.
        opt_model = self.model
        # if self.args.unfreeze_mm_vision_tower:
        #     opt_model.get_model().vision_tower_aux_list = nn.ModuleList(opt_model.get_vision_tower_aux_list())
        #     self.param_to_name = map_params_to_module_names([opt_model])
        # pyre-fixme[16]: `Trainer` has no attribute `optimizer`.
        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            # pyre-fixme[16]: `Trainer` has no attribute `mm_projector_lr`.
            assert not (self.args.mm_projector_lr and self.args.mm_vision_sampler_lr)
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args
            )

            self.optimizer = optimizer_cls(
                optimizer_grouped_parameters, **optimizer_kwargs
            )
        return self.optimizer

In the model codebase, the create_optimizer() method above builds the weight-decay / no-weight-decay parameter groups. When I print the named parameters inside create_optimizer(), the output is as follows:

2025-02-22 17:01:20,658 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing:
tensor([], device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
2025-02-22 17:01:20,695 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing:
tensor([], device='cuda:1', dtype=torch.bfloat16, requires_grad=True)
...
2025-02-22 17:01:20,690 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing:
tensor([ 0.0103,  0.0090,  0.0134,  ...,  0.0049, -0.0025, -0.0052],
       device='cuda:2', dtype=torch.bfloat16, requires_grad=True)
2025-02-22 17:01:20,691 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing:
tensor([-0.0099, -0.0302, -0.0054,  ..., -0.0038, -0.0027, -0.0015],
       device='cuda:3', dtype=torch.bfloat16, requires_grad=True)

As can be seen, on ranks 0 and 1 the lm_head layer wrapped in the FSDP unit has an empty parameter tensor, while on ranks 2 and 3 it has actual values. So I think the error might stem from this FSDP sharding, where the same lm_head layer is sharded onto ranks 2 and 3 but is empty on ranks 0 and 1. Therefore, my question is:

Why does FSDP in the Trainer shard a layer such that its local shard is empty on certain ranks, possibly leading to the error above?
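(For reference, this is a sketch of the kind of per-rank check behind the lm_head log lines above; report_lm_head is a hypothetical helper, assuming torch.distributed is initialized and the model is already FSDP-wrapped:)

# Hypothetical helper (not from my training script): print, on each rank, how many
# elements of lm_head.weight the local FSDP shard actually holds.
import torch.distributed as dist

def report_lm_head(model):
    rank = dist.get_rank()
    for name, param in model.named_parameters():
        if name.endswith("lm_head.weight"):
            print(f"rank {rank}: {name} numel={param.numel()} "
                  f"requires_grad={param.requires_grad}")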

1. I further investigated optim.Optimizer.state_dict() and noticed that for ranks 1, 2 and 3, the self.state dict is populated:
2025-02-24 04:16:07,667 - root - DEBUG - @tcm: In optim.Optimizer.state_dict(): self.state=defaultdict(<class 'dict'>, {Parameter containing:
tensor([-0.0099, -0.0302, -0.0054,  ..., -0.0038, -0.0027, -0.0015],
       device='cuda:3', requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([-2.1631e-11, -6.3862e-11, -3.7596e-11,  ..., -1.6094e-12, 7.6394e-12, -6.6093e-12], device='cuda:3'), 'exp_avg_sq': tensor([4.6788e-23, 4.0783e-22, 1.4135e-22,  ..., 2.5902e-25, 5.8360e-24, 4.3683e-24], device='cuda:3')}})
2025-02-24 04:16:07,667 - root - DEBUG - @tcm: In optim.Optimizer.state_dict(): self.state=defaultdict(<class 'dict'>, {Parameter containing:
tensor([ 0.0103,  0.0090,  0.0134,  ...,  0.0049, -0.0025, -0.0052],
       device='cuda:2', requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([-1.9513e-08,  5.0979e-09,  1.5206e-08,  ..., -2.0051e-10, 9.6136e-11, -1.6412e-10], device='cuda:2'), 'exp_avg_sq': tensor([3.8075e-17, 2.5989e-18, 2.3122e-17,  ..., 4.0205e-21, 9.2421e-22, 2.6935e-21], device='cuda:2')}})

but on rank 0, the self.state dict is empty:

2025-02-24 04:16:18,691 - root - DEBUG - @tcm: In optim.Optimizer.state_dict(): self.state=defaultdict(<class 'dict'>, {})

So I think this is the root cause of the error I have been describing. I don’t understand why the optimizer’s self.state is empty on rank 0 but not on the other ranks.
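One observation that may or may not be relevant: with a plain (non-FSDP) optimizer, AdamW only creates state entries for parameters that actually received a gradient before step(), as in the minimal sketch below. Whether the same mechanism explains the empty self.state on rank 0 under FSDP is exactly what I am unsure about.

# Minimal standalone check (no FSDP): AdamW only materializes exp_avg/exp_avg_sq
# for parameters that had a .grad when step() was called.
import torch

p_used = torch.nn.Parameter(torch.randn(4))
p_unused = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.AdamW([p_used, p_unused], lr=1e-3)

p_used.sum().backward()  # only p_used receives a gradient
opt.step()

print(opt.state_dict()["state"].keys())  # dict_keys([0]) -> no state for p_unused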

I’ve been trying my best to find the root cause and fix this error, but the codebase is large and complex, so I am seeking help from the community.

Thanks in advance.