TP/FSDP + sync_module_states/cpu_offload

Hi, I’m working on writing a distributed training guide.

The chapter I’m currently working on covers 2D parallelism (tensor parallel + FSDP). The full code is in the PR here: [WIP] 2d parallelism chapter by corey-lambda · Pull Request #39 · LambdaLabsML/distributed-training-guide · GitHub

I have a couple of questions, but first here’s some background:

  1. I have a 64 GPU cluster (8 nodes, each with 8 H100 GPUs)
  2. I’m training Llama 3.1 405B
  3. I load the pretrained weights on rank 0 on the CPU, and every other rank uses the meta device
  4. When just using FSDP (you can see the full code for this here: distributed-training-guide/06-training-llama-405b/train_llm.py at main · LambdaLabsML/distributed-training-guide · GitHub), I used sync_module_states=True and CPUOffload(offload_params=True) (a sketch of this setup is right after this list)
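
For reference, here is a minimal sketch of that FSDP-only setup. It is illustrative, not the exact code from the repo; the model name, environment variables, and param_init_fn are stand-ins:

```python
# Minimal sketch (illustrative): rank 0 holds the real weights on CPU, every other
# rank builds the model on the meta device, and FSDP broadcasts from rank 0 and
# shards with CPU offloading.
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from transformers import AutoConfig, AutoModelForCausalLM

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
rank = dist.get_rank()

model_name = "meta-llama/Llama-3.1-405B"  # illustrative
if rank == 0:
    # Real pretrained weights, kept in host RAM.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
else:
    # Shape-only "shell" model with no storage allocated.
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name))

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,                      # rank 0 broadcasts the real weights
    cpu_offload=CPUOffload(offload_params=True),  # keep shards in host RAM between steps
    # One way to materialize the meta params on the non-zero ranks:
    param_init_fn=lambda m: m.to_empty(device="cpu", recurse=False),
    # (the real code also passes an auto-wrap policy and mixed precision, omitted here)
)
```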

For 2d parallelism:

  1. I’m using the PyTorch TP/FSDP example code as a reference (I can’t include the link for some reason, but I can share it in a comment)
  2. I have a 2D device mesh (size 8 in the DP dimension and size 8 in the TP dimension)
  3. I’m applying tensor parallelism to all the weights of the model along the TP dimension of the 2D mesh
  4. I’m applying FSDP along the DP dimension of the 2D mesh (a rough sketch of the wiring is right after this list)
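
Roughly, the 2D wiring looks like this. This is a sketch modeled on the PyTorch TP/FSDP example; the parallelize plan is abbreviated and assumes a Llama-style module layout, not the full plan from the PR:

```python
# Sketch of the 2D (TP + FSDP) wiring on 64 GPUs; `model` is assumed to be a
# Llama-style model that already exists (see the loading sketch above).
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 64 GPUs -> 8-way data parallel x 8-way tensor parallel.
mesh_2d = init_device_mesh("cuda", (8, 8), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # the 8 GPUs within a node
dp_mesh = mesh_2d["dp"]  # the same GPU index across the 8 nodes

# 1) Tensor-parallelize each transformer block along the "tp" dimension.
for block in model.model.layers:
    parallelize_module(
        block,
        tp_mesh,
        {
            "self_attn.q_proj": ColwiseParallel(),
            "self_attn.k_proj": ColwiseParallel(),
            "self_attn.v_proj": ColwiseParallel(),
            "self_attn.o_proj": RowwiseParallel(),
            "mlp.gate_proj": ColwiseParallel(),
            "mlp.up_proj": ColwiseParallel(),
            "mlp.down_proj": RowwiseParallel(),
        },
    )

# 2) Shard along the "dp" dimension with FSDP (use_orig_params is needed for DTensors).
model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
```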

Now for my questions:

  1. I keep getting OOM errors when FSDP starts materializing the tensors. I’m wondering whether CPU offloading works with the DTensors produced by tensor parallel (maybe it’s the combination with the meta device?)
  2. How does sync_module_states=True work with FSDP & device_mesh? Since the weights are no longer identical across ranks (they are sharded across the device mesh groups), do I need to do something else to ensure that all the ranks end up with the correct weights? (A quick illustration of what I mean is below.)
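
To illustrate question 2: after TP, every weight is a DTensor and each TP rank only holds its own 1/8 shard, so a plain rank-0 broadcast can’t reconstruct the other shards. A quick check (using the torch 2.4 private import path for DTensor):

```python
# After parallelize_module, parameters are DTensors; each TP rank only holds
# its own local shard of every weight.
from torch.distributed._tensor import DTensor  # torch 2.4.x path

for name, p in model.named_parameters():
    if isinstance(p, DTensor):
        print(name, p.placements, tuple(p.shape), "->", tuple(p.to_local().shape))
        break
```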

Another update:

Since I moved to TP, I removed the auto_wrap_policy from the FSDP constructor, so the only module wrapped by FSDP is the top-level module.
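
For reference, the auto-wrap policy in question was something like the standard transformer policy (illustrative):

```python
# The usual transformer auto-wrap policy: with it, every LlamaDecoderLayer
# becomes its own FSDP unit instead of one giant top-level unit.
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
# FSDP(model, auto_wrap_policy=wrap_policy, ...)
```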

Interestingly, during initialization parameters are moved between devices in one batch per FSDP unit rather than one at a time. Since there is only one FSDP unit here, all 405B params (split into 8 TP shards) are moved directly to the GPU when param_init_fn is called: pytorch/torch/distributed/fsdp/_init_utils.py at v2.4.1 · pytorch/pytorch · GitHub.
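
In other words, with only the top-level module wrapped, a param_init_fn along these lines (illustrative) gets applied to the entire model at once:

```python
# A typical GPU-targeting param_init_fn. FSDP applies this per wrapped unit,
# so with a single top-level unit it materializes every rank's TP shard of all
# 405B params on the GPU in one shot.
import torch

def param_init_fn(module: torch.nn.Module) -> None:
    # Allocate uninitialized storage on the GPU for this module's meta params.
    module.to_empty(device=torch.device("cuda"), recurse=False)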

Since each node has 8 H100s (80 GB each), I have a total of 640 GB of GPU memory per node, but the 405B bf16 params take ~810 GB (405e9 params × 2 bytes), which is why I need CPU offloading or a staged device transfer.

I changed my param_init_fn to initialize on the CPU instead, but now I’m getting another error about sync_module_states. I may need to bring back the auto_wrap_policy just to get staged CPU offloading.
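
The change itself is just the target device of that materialization step, i.e. something like:

```python
# CPU variant of the param_init_fn: the empty storage lands in host RAM, so the
# per-node footprint during materialization is bounded by CPU memory instead of
# the 8 x 80 GB of GPU memory.
import torch

def param_init_fn(module: torch.nn.Module) -> None:
    module.to_empty(device=torch.device("cpu"), recurse=False)
```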

I also think I need to disable sync_module_states. With TP, rank 0 only has a shard of the model parameters, not the full set, so it can’t actually broadcast the full weights to all the other ranks.

Instead, I think I need to have every node load the weights into CPU memory at first and not use the meta device. Or figure out some way for node 0 (ranks 0-7) to broadcast its weights to the corresponding ranks on the other nodes (i.e. rank 1 would broadcast to ranks 9, 17, etc.).
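
A rough sketch of that broadcast idea, assuming the 8x8 ("dp", "tp") mesh from above so that the first member of each DP group lives on node 0; untested at this scale:

```python
# After TP sharding, broadcast each rank's local shard from the DP-group member
# on node 0 to the matching ranks on the other nodes. Assumes the tensors live
# on a device the DP group's backend supports (CUDA for NCCL).
import torch.distributed as dist
from torch.distributed._tensor import DTensor  # torch 2.4.x path

dp_group = dp_mesh.get_group()                    # this rank's 8-way DP process group
src = dist.get_process_group_ranks(dp_group)[0]   # global rank of the node-0 member

for p in model.parameters():
    local = p.to_local() if isinstance(p, DTensor) else p.data
    dist.broadcast(local, src=src, group=dp_group)  # in-place on the receiving ranks
```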

Do you think it would be possible to use FSDP2? Its TP support should be simpler and better. GitHub - pytorch/torchtitan: A native PyTorch Library for large model training
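
Roughly, the FSDP2 sharding along the DP mesh would look like this (sketch; the import path is the torch 2.4 one, and `dp_mesh`/`model` are assumed from your setup above):

```python
# FSDP2-style sharding along the DP mesh. fully_shard keeps parameters as
# DTensors, which composes more naturally with the TP DTensors.
from torch.distributed._composable.fsdp import fully_shard  # torch 2.4.x path

for block in model.model.layers:
    fully_shard(block, mesh=dp_mesh)  # each decoder layer is its own unit
fully_shard(model, mesh=dp_mesh)      # root wraps whatever is left
```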

Okay, I will check it out. From reading through some of the torchtitan code, I’m learning that TP support in PyTorch is still evolving rapidly. Thanks for linking that.