I ran into this problem: when I use FSDP with three GPUs to fine-tune the Llama 3 vision model, every GPU uses more memory than it does under DDP or on a single GPU. I get an OOM error, and it happens in _flat_param.py when the parameters are being flattened. What's more, even when the model does load under FSDP, it uses more memory than under DDP.
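To make the setup concrete, here is a minimal, illustrative sketch of the kind of wrapping I mean (a toy model run under torchrun, not my actual finetuning.py; the wrap policy and layer class are placeholders for the real vision model):

```python
# Run with: torchrun --nproc_per_node=3 fsdp_wrap.py
import functools

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Stand-in for the vision-language model; the real script builds Llama 3.2 Vision.
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=4
)

# Wrap each transformer block as its own FSDP unit; FSDP then flattens the
# parameters of each unit into a FlatParameter (the _flat_param.py step where
# I hit the OOM).
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)
model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)
```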
Could this be a bug in PyTorch?
Did you try FSDP2, as it's supposed to use less memory? From the RFC:
We have validated the prototype on Llama-like models, achieving on-par throughput while using less memory.
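For reference, here is a minimal sketch of what sharding looks like with FSDP2 (illustrative toy model, not a drop-in recipe; the import path depends on your PyTorch version, and the layer/model names are placeholders):

```python
# Run under torchrun, same as an FSDP1 script.
import torch
import torch.distributed as dist
import torch.nn as nn
# PyTorch >= 2.6; on older versions the path is torch.distributed._composable.fsdp
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=4
)

# Shard each block, then the root module. fully_shard modifies the module in
# place: there is no wrapper class, and parameters become sharded DTensors
# rather than being flattened into one FlatParameter per unit.
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)
```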
So far I have only used FSDP, because I based my own finetuning.py on llama-recipes. Thank you for the suggestion, I will try it!
Thanks for the reply. I tried FSDP2, and it definitely composes better with other features like activation checkpointing and torch.compile. I can see GPU memory savings now.
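In case it helps someone later, this is roughly how I combined the three. It is an illustrative sketch on a toy model, not my actual script, and the ordering (activation checkpointing, then sharding, then compile) is just what worked for me, not an official recipe:

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import fully_shard  # PyTorch >= 2.6 import path

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=4
)

# 1) Activation checkpointing on each transformer block.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, nn.TransformerEncoderLayer),
)

# 2) FSDP2 sharding, per block and then the root module.
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

# 3) torch.compile on top of the sharded, checkpointed model.
model = torch.compile(model)
```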