I'm using the PyTorch TP/FSDP example code as a reference (I can't include the link for some reason, but I can share it in a comment).
I have a 2D device mesh (size 8 in the DP dimension and size 8 in the TP dimension).
I'm applying tensor parallelism to all of the model's weights along the TP dimension of the 2D mesh.
I'm applying FSDP along the DP dimension of the 2D mesh.
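For reference, my setup looks roughly like this (a sketch, not my exact code: `build_model` and the module names in the parallelize plan are placeholders):

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 64 ranks total: 8 along "dp" x 8 along "tp"
mesh_2d = init_device_mesh("cuda", (8, 8), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]
dp_mesh = mesh_2d["dp"]

model = build_model()  # placeholder: constructs the (meta-device) model

# Tensor-parallelize each block's projections along the "tp" dimension
for block in model.layers:
    parallelize_module(
        block,
        tp_mesh,
        {
            "attention.wq": ColwiseParallel(),
            "attention.wk": ColwiseParallel(),
            "attention.wv": ColwiseParallel(),
            "attention.wo": RowwiseParallel(),
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(),
        },
    )

# Shard the TP'd model with FSDP along the "dp" dimension of the same mesh
model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
```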
Now for my questions:
I keep getting OOM errors when FSDP starts materializing the tensors. I'm wondering whether CPU offloading works with the DTensors produced by tensor parallelism (maybe when combined with the meta device?).
How does sync_module_states=True work with FSDP & device_mesh? Since the weights are no longer replicated across ranks (they are sharded across the device-mesh groups), do I need to do something else to ensure that all the ranks end up with the correct weights?
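For context, this is roughly the wrapping call in question (again a sketch, adapted from the example, not my exact code):

```python
import torch
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

def param_init_fn(module: torch.nn.Module):
    # Materialize this module's meta-device params/buffers with empty storage
    # on the local GPU -- this is the step where I hit the OOM.
    module.to_empty(device=torch.device("cuda"), recurse=False)

model = FSDP(
    model,
    device_mesh=dp_mesh,
    param_init_fn=param_init_fn,
    cpu_offload=CPUOffload(offload_params=True),  # does this play nicely with DTensor params?
    sync_module_states=True,  # and does this broadcast correctly when the weights are TP-sharded?
)
```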
Since I moved to using TP, I removed the auto wrapping policy from the FSDP constructor, which means the only module wrapped by FSDP is the top-level module.
I'm using 8 H100 nodes (8 × 80 GB per node), so I have 640 GB of GPU memory per node, but the 405B bf16 params alone take ~810 GB of GPU memory (405e9 params × 2 bytes), which is why I need CPU offloading or a staged device transfer.
I changed my param_init_fn to initialize on the CPU device instead, but now I'm getting another error about sync module states. I may need to bring back the auto_wrap_policy just for staged CPU offloading.
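Roughly what I'm trying now (another sketch; TransformerBlock is a placeholder for the actual layer class):

```python
import functools

import torch
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

def cpu_param_init_fn(module: torch.nn.Module):
    # Materialize meta-device params on CPU so blocks can be moved to GPU
    # one at a time instead of all at once.
    module.to_empty(device=torch.device("cpu"), recurse=False)

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},  # placeholder layer class
)

model = FSDP(
    model,
    device_mesh=dp_mesh,
    auto_wrap_policy=wrap_policy,
    param_init_fn=cpu_param_init_fn,
    cpu_offload=CPUOffload(offload_params=True),
    sync_module_states=True,  # this is the call that now errors about sync module states
)
```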
I also think I need to disable sync_module_states. With TP, rank 0 only has a shard of the model parameters, not the full set, so it can't actually broadcast the full weights to all the other ranks.
Instead I think I need to have all the nodes load the weights into CPU at first and not use the meta device. Or figure out some way for all of node 0 (ranks 0-7) to broadcast weights to the other nodes (i.e. rank 1 would broadcast to ranks 9/17/etc.).
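Something like this rough sketch is what I have in mind for the broadcast option (completely unverified; dp_mesh is the "dp" slice of the 2D mesh from above, and this would have to run after TP sharding but before FSDP wrapping):

```python
import torch.distributed as dist

# After the node-0 ranks (0-7) have loaded real weights and TP has sharded them,
# broadcast each rank's local TP shard across its DP group, so e.g. rank 1's
# shard goes to ranks 9, 17, 25, ...
dp_group = dp_mesh.get_group()  # process group along the "dp" mesh dimension
src_rank = dist.get_process_group_ranks(dp_group)[0]  # the node-0 member of this DP group

for param in model.parameters():
    # For DTensor params, broadcast the local shard; NCCL needs these on GPU.
    local = param.to_local() if hasattr(param, "to_local") else param.data
    dist.broadcast(local, src=src_rank, group=dp_group)
```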
Okay, I'll check it out. From reading through some of the torchtitan code, I'm learning that TP support in PyTorch is still evolving rapidly. Thanks for linking that.