How to use FSDP with LoRA?

I am trying to use Microsoft’s loralib (GitHub - microsoft/LoRA: Code for loralib, an implementation of "LoRA: Low-Rank Adaptation of Large Language Models") inside an FSDP-wrapped model. However, this does not work, because the LoRA parameter tensors (LoRA/loralib/layers.py at 4c0333854cb905966f8cc4e9a74068c1e507c7b7 · microsoft/LoRA · GitHub) get sharded, which causes issues when computing the LoRA update (LoRA/loralib/layers.py at 4c0333854cb905966f8cc4e9a74068c1e507c7b7 · microsoft/LoRA · GitHub).
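
For reference, here is a stripped-down sketch of the kind of setup I mean (not my exact code; the real model wraps full transformer blocks, and torch.distributed is assumed to already be initialized, e.g. via torchrun):

```python
import torch.nn as nn
import loralib as lora
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # lora.Linear keeps the frozen base weight plus trainable lora_A / lora_B
        self.q_proj = lora.Linear(1024, 1024, r=8)

    def forward(self, x):
        return self.q_proj(x)

model = TinyModel()
lora.mark_only_lora_as_trainable(model)

# Wrapping with FSDP shards lora_A / lora_B along with everything else,
# which is where the LoRA update computation breaks for me.
model = FSDP(model, use_orig_params=True)  # use_orig_params so frozen + trainable params can share a unit
```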

Ideally, I would like to avoid applying FSDP to the matrices that are being LoRA-tuned (in this case, the q/k/v projections in a transformer).

How can I do this?

@weifengpy may be the right person to help

Ideally, I would like to avoid applying FSDP to the matrices that are being LoRA-tuned

Is this from torch.distributed.fsdp.FullyShardedDataParallel? If yes, you can pass ignored_states so that the q/k/v projections are not sharded. Example code: pytorch/test/distributed/fsdp/test_fsdp_ignored_modules.py at dbe6fce185afc3a59f7a2063289c809f39d36c32 · pytorch/pytorch · GitHub
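
Roughly like this (a minimal sketch, not tested end-to-end; model.layers[i].attn.qkv_proj is a hypothetical attribute name for the projections you want to keep unsharded, and torch.distributed is assumed to be initialized):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = build_model_with_lora()  # hypothetical constructor

# Modules (or parameters) that FSDP should leave alone
ignored = [layer.attn.qkv_proj for layer in model.layers]

model = FSDP(
    model,
    ignored_states=ignored,  # these stay unsharded / replicated on every rank
)
```

Keep in mind that FSDP does not manage ignored states at all, so gradients of trainable parameters inside them are not reduced by FSDP and would need to be synchronized separately.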

I am also curious about your wrapping policy. If you wrap at the transformer-layer level, FSDP unshards layer.parameters() before layer.forward, so technically it is fine to shard the q/k/v projections together with the LoRA adapters. Here is an example of LoRA + FSDP2 in TorchTune: torchtune/README.md at main · pytorch/torchtune · GitHub
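
To make the transformer-layer wrapping concrete, here is a rough FSDP1 sketch (MyLoRATransformer and TransformerBlock are placeholders for your own classes, and torch.distributed is assumed to be initialized):

```python
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

model = MyLoRATransformer()  # placeholder for your LoRA-injected model

# Each TransformerBlock becomes its own FSDP unit, so its parameters,
# including the LoRA adapters inside the q/k/v projections, are
# all-gathered right before block.forward and resharded afterwards.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    use_orig_params=True,  # lets frozen base weights and trainable LoRA params share one unit
)
```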
