[FSDP] LoRA tensors seem to vanish into oblivion when using FullyShardedDataParallel and peft

Hello all,

We recently started using FSDP through the :hugs: Accelerate library and are running into weird issues when trying to train with LoRA from the :hugs: peft library.

I’ll describe the current issue I’m facing and will also discuss a few other things that I’ve tried doing.

Current issue

With my current setup, we can run the forward pass of an FSDP model with LoRA, but we cannot save the model.

Saving LoRA checkpoints fails because the LoRA layers seem to lose their dimensionality by the time we call the function that saves the checkpoint. To illustrate what I mean, please see the following error:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ec2-user/training/src/instructlab/training/main_ds.py", line 805, in <module>
[rank0]:     main(args)
[rank0]:   File "/home/ec2-user/training/src/instructlab/training/main_ds.py", line 613, in main
[rank0]:     train(
[rank0]:   File "/home/ec2-user/training/src/instructlab/training/main_ds.py", line 462, in train
[rank0]:     save_hf_format_accelerate(
[rank0]:   File "/home/ec2-user/training/src/instructlab/training/utils.py", line 692, in save_hf_format_accelerate
[rank0]:     model.module.merge_adapter()
[rank0]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/peft/tuners/tuners_utils.py", line 550, in merge_adapter
[rank0]:     module.merge(adapter_names=adapter_names)
[rank0]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/peft/tuners/lora/layer.py", line 481, in merge
[rank0]:     delta_weight = self.get_delta_weight(active_adapter)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/peft/tuners/lora/layer.py", line 548, in get_delta_weight
[rank0]:     output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
[rank0]:                               ~~~~~~~~~^~~~~~~~~~
[rank0]: RuntimeError: inconsistent tensor size, expected tensor [4096] and src [11008] to have the same number of elements, but got 4096 and 11008 elements respectively

This is a really strange error considering how these tensors are initialized (from :hugs: peft):

    def update_layer(
        self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora: bool = False
    ):

        # ...

        # Actual trainable parameters
        self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
        self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)

Inspecting the LoRA adapters at this stage via distributed.breakpoint() produces the following results:

> self.lora_A[adapter_name].weight.shape
(11008, 4)
> self.lora_B[adapter_name].weight.shape
(4, 4096)

But in the self.get_delta_weight method, where the matrix multiplication between these two tensors is attempted, inspecting their shapes yields the following:

> self.lora_A[active_adapter].weight.shape
(11008,)
> self.lora_B[active_adapter].weight.shape
(4096,)

This seems to imply that (fan_in + fan_out) x (lora_rank - 1) parameters simply disappear at some point. I tried to track down where this could be happening but was unsuccessful.
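
For what it’s worth, those 1-D shapes look a lot like FSDP’s flattened shards rather than the original 2-D LoRA matrices, which makes me suspect merge_adapter() is running on parameters that are still sharded. Below is a sketch of what I’m planning to try next (untested; it assumes model is the FSDP-wrapped module and model.module is the PeftModel):

# Untested sketch: gather the full, unsharded parameters before merging so the
# LoRA matrices have their original 2-D shapes when get_delta_weight runs.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

with FSDP.summon_full_params(model):
    model.module.merge_adapter()
    model_state = model.module.state_dict()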

Here’s my current setup:

FSDP and LoRA configuration:

def get_fsdp_config(args, model):
    # Third Party
    from accelerate.utils import FullyShardedDataParallelPlugin
    from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
    from peft.utils.other import fsdp_auto_wrap_policy

    fsdp_plugin = FullyShardedDataParallelPlugin(
        auto_wrap_policy=fsdp_auto_wrap_policy(model),
        limit_all_gathers=True,
        mixed_precision_policy=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy],
        cpu_offload=CPUOffload(args.cpu_offload_params_fsdp),
        use_orig_params=True,
        cpu_ram_efficient_loading=True,
        sync_module_states=False,
    )
    return fsdp_plugin


peft_config = LoraConfig(
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
    r=args.lora_r,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=args.lora_target_modules,
)
model = get_peft_model(model, peft_config)
accel_args = {
    "fsdp_plugin": get_fsdp_config(args, model),
}
accelerator = Accelerator(**accel_args, mixed_precision="bf16")
accelerator.even_batches = False
return accelerator
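
For completeness: with Accelerate, the FSDP wrapping itself happens once the model goes through accelerator.prepare(), roughly like the following (the optimizer and dataloader names here are placeholders, not our actual code):

# The actual FSDP wrapping is applied inside accelerator.prepare().
# `optimizer` and `train_loader` are placeholders for our real objects.
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)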

Other things I’ve tried

use_orig_params=False leads to an error about data not allocated yet

With the current mixed precision settings and use_orig_params=False, we cannot run training because we get a different error, this time about the data for the LoRA layers not being allocated yet:

[rank0]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/peft/tuners/lora/layer.py", line 548, in get_delta_weight
[rank0]:     output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]
[rank0]:                               ~~~~~~~~~^~~~~~~~~~
[rank0]: RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
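
A quick way to check whether the original LoRA weights still own any storage at that point would be something like the following (untested sketch; with use_orig_params=False, FSDP keeps only its internal FlatParameter shards outside of forward/backward):

# Untested sketch: print whether each LoRA weight still has storage allocated.
# A storage size of 0 would mean FSDP has freed the original tensor and only
# its FlatParameter shard remains.
for name, module in model.named_modules():
    if hasattr(module, "lora_A"):
        for adapter_name, linear in module.lora_A.items():
            weight = linear.weight
            print(name, adapter_name, tuple(weight.shape), weight.untyped_storage().size())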

use_orig_params=False and mixed_precision=None leads to data type mismatch between float32 and bfloat16

With this configuration, we run into a problem where, three iterations into the forward pass, an input with dtype = float32 is passed into a Linear layer whose weight has dtype = bfloat16:

[rank1]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/torch/nn/modules/linear.py", line 117, in forward
[rank1]:     return F.linear(input, self.weight, self.bias)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16

LoRA mixed precision error

When use_orig_params=True and no mixed precision setting is given to FSDP, we again run into the data type mismatch, this time on the very first forward pass:

[rank1]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 858, in forward
[rank1]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/ec2-user/training/venv/lib64/python3.11/site-packages/torch/nn/modules/linear.py", line 117, in forward
[rank1]:     return F.linear(input, self.weight, self.bias)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
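
One workaround I have been considering for the dtype mismatch (not yet validated) is to cast the whole PEFT model to bfloat16 before handing it to FSDP, so the LoRA weights and the base weights share a dtype:

# Untested sketch: cast base weights and LoRA weights to bfloat16 before FSDP
# wrapping so F.linear never sees mixed float32 / bfloat16 operands. This gives
# up keeping the LoRA weights in float32.
model = get_peft_model(model, peft_config)
model = model.to(torch.bfloat16)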

Does anyone here have an idea for why this is happening? Any insight would be greatly appreciated!

Can you show me the code you are using to save your model?

@pratikkorat26 Sure, here’s the function that saves checkpoints:

def save_hf_format_accelerate(
    args,
    model,
    tokenizer,
    accelerator: Accelerator,
    samples_seen,
    convert_granite=True,
    is_lora=False,
):
    log_rank_0(
        f"\033[93mSaving model in huggingface format at samples_seen: {samples_seen}\033[0m",
        to_print=True,
    )
    start = time.time()

    final_output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}"
    if args.is_granite and convert_granite:
        tmpdir = TemporaryDirectory("w")  # pylint: disable=consider-using-with
        output_dir = Path(tmpdir.name)
    else:
        output_dir = final_output_dir

    CONFIG_NAME = "config.json"
    output_config_file = output_dir / CONFIG_NAME

    get_state_dict_unpatched = accelerator.get_state_dict

    def _get_state_dict_patched(model, unwrap=False):
        return get_state_dict_unpatched(model, unwrap=unwrap)

    accelerator.get_state_dict = _get_state_dict_patched

    if accelerator.is_main_process:
        if is_lora:
            model.module.merge_adapter()
            model_state = model.module.state_dict()

        output_dir.mkdir(parents=True, exist_ok=True)
        if not model.module.config.architectures and convert_granite:
            model.module.config.architectures = ["LlamaForCausalLM"]
            warnings.warn(
                f"Adding architectures to ckpt: {model.module.config.architectures}",
            )
        model.module.config.to_json_file(output_config_file)
        tokenizer.save_pretrained(output_dir)

        if is_lora:
            save_dict_accelerate(
                accelerator,
                model_state,
                save_directory=output_dir,
                max_shard_size="5GB",
                safe_serialization=True,
            )
            model.module.unmerge_adapter()

    if not is_lora:
        accelerator.save_model(
            model,
            save_directory=output_dir,
            max_shard_size="5GB",
            safe_serialization=True,
        )

    if args.is_granite and convert_granite and accelerator.is_main_process:
        # export doesn't like the directory to already exist
        if final_output_dir.exists():
            shutil.rmtree(final_output_dir)
        export_to_huggingface(
            pretrained_model_name_or_path=tmpdir.name,
            save_path=final_output_dir,
            model_type="llama",
        )
        tmpdir.cleanup()

    log_rank_0(f"\033[93mModel saved in {final_output_dir}\033[0m", to_print=True)
    log_rank_0(f"saving took {time.time() - start} seconds")
    dist.barrier()

    accelerator.get_state_dict = get_state_dict_unpatched

granite here just means padding-free transformers. The save logic was originally written for DeepSpeed ZeRO Stage 2, and the FSDP runs I’ve done so far have used SHARD_GRAD_OP as the sharding strategy. I’ve also tried setting the state dict type to FULL_STATE_DICT, with no success.
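
For reference, this is roughly what I mean by setting the state dict type (a sketch, not verbatim from our code):

# Sketch of forcing a full (unsharded) state dict when saving:
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

FSDP.set_state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
)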