Efficiently Training Multiple Large Models with PyTorch FSDP: Best Practices?

Hi everyone,

I’m working on a reinforcement learning framework for training large language models (LLMs) and need to run three very large models simultaneously. The challenge is that none of these models fits in a single GPU’s memory, so I’m exploring ways to distribute them efficiently across GPUs and nodes using PyTorch’s Fully Sharded Data Parallel (FSDP).
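For concreteness, here is a minimal sketch of what I’m currently attempting: all three models wrapped with FSDP over the same world process group, so every GPU ends up holding shards of all three. The model names and the actor/critic/reference split below are just placeholders for illustration.

```python
# Minimal sketch of my current attempt (launched with torchrun).
# Model names and the actor/critic/reference split are placeholders.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM

dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# Three roles typical of RL fine-tuning: a trainable actor, a trainable critic,
# and a frozen reference copy of the actor.
actor = AutoModelForCausalLM.from_pretrained("my-org/actor-7b")    # placeholder name
critic = AutoModelForCausalLM.from_pretrained("my-org/critic-7b")  # placeholder name
reference = AutoModelForCausalLM.from_pretrained("my-org/actor-7b")

# Each wrap shards its model over the default (world) process group,
# so every GPU ends up holding shards of all three models.
actor = FSDP(actor, device_id=local_rank)
critic = FSDP(critic, device_id=local_rank)
reference = FSDP(
    reference,
    device_id=local_rank,
    cpu_offload=CPUOffload(offload_params=True),  # keep the frozen model's params off-GPU when idle
)
```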

Here are my specific questions:

  1. Can PyTorch FSDP handle having all three models share a single GPU, i.e., can one GPU hold shards of all three FSDP-wrapped models at the same time?
  2. If not, what would be the best way to distribute the models across GPUs/nodes while keeping memory usage and communication efficient? (One layout I’ve been considering is giving each model its own process group; see the sketch after this list.)
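
To make question 2 concrete, the layout I’ve been considering is sketched below for a single hypothetical 8-GPU node, with trivial stand-in models instead of real LLMs: each model gets its own process group, so it is sharded over only a subset of the GPUs and its all-gathers/reduce-scatters never touch the other models’ GPUs.

```python
# Hypothetical layout for one 8-GPU node; the rank split and stand-in models
# are made up purely for illustration.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def build_placeholder_model() -> nn.Module:
    # Stand-in for loading a real LLM checkpoint.
    return nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# new_group is a collective: every rank must call it, even ranks outside the group.
actor_group = dist.new_group(ranks=[0, 1, 2, 3])
critic_group = dist.new_group(ranks=[4, 5])
reference_group = dist.new_group(ranks=[6, 7])

device_id = torch.cuda.current_device()
if rank in (0, 1, 2, 3):
    # The actor is sharded only across ranks 0-3.
    model = FSDP(build_placeholder_model(), process_group=actor_group, device_id=device_id)
elif rank in (4, 5):
    model = FSDP(build_placeholder_model(), process_group=critic_group, device_id=device_id)
else:
    model = FSDP(build_placeholder_model(), process_group=reference_group, device_id=device_id)
```

The trade-off I see with this layout is that each model only gets a fraction of the total GPU memory and compute, and intermediate results (rollouts, log-probs, values) have to be shipped between groups explicitly, which is exactly the memory/communication balance I’m asking about.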