How does the FSDP algorithm work?

Hi
I’m trying to understand FSDP but can’t find any good resources. I’ve broken it down into steps. Can someone tell me if I’m right?
So let’s say that I have 4 GPUs and a model that can’t fit into one GPU, so it gets split into 2 shards.

  1. load weights from shard 1 into each GPU
  2. split the batch of data into 4 parts
  3. do the forward pass with part 1 on GPU 1, part 2 on GPU 2, etc.
  4. save intermediate outputs for backprop locally. They are different on each GPU because of the different data
  5. offload weights from shard 1
  6. load weights from shard 2 into each GPU
  7. do the forward pass with the outputs from step 4 and save intermediate outputs. Again, they are different on each GPU because of the different inputs. The outputs from step 4 are still saved locally
  8. calculate the loss
  9. load the optimizer state for the weights from shard 2
  10. calculate gradients for the weights from shard 2. They are different on each GPU
  11. synchronize gradients so they are the same on all GPUs
  12. update weights on each GPU locally. After the update, each GPU has the same weights
  13. offload the optimizer state and weights from shard 2
  14. load weights from shard 1 and the optimizer state for them
  15. calculate gradients for the weights from shard 1. They are different on each GPU
  16. synchronize gradients so they are the same on all GPUs
  17. update weights on each GPU locally. After the update, each GPU has the same weights
  18. offload the optimizer state
  19. load another batch of data and split it into 4 parts
  20. repeat steps 3-19 until a stopping criterion is met
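
To make the gradient steps (10-12 and 15-17) concrete, here is a toy sketch in plain Python that simulates the 4 GPUs in a single process. It only illustrates the steps as listed above, under loud assumptions: the "model" is two scalar weights (`w1` standing in for shard 1, `w2` for shard 2), the update is plain SGD, and `forward_backward` and `data_parallel_step` are hypothetical names. It is not FSDP's actual implementation; real FSDP shards parameters, gradients, and optimizer state across ranks and uses collectives such as all-gather and reduce-scatter.

```python
def forward_backward(w1, w2, xs, ys):
    """Mean-squared-error loss and gradients for pred = w2 * (w1 * x)."""
    n = len(xs)
    loss = g1 = g2 = 0.0
    for x, y in zip(xs, ys):
        h = w1 * x                   # "shard 1" forward; h is kept for backprop
        err = w2 * h - y             # "shard 2" forward, then the error
        loss += err ** 2 / n
        g2 += 2 * err * h / n        # dLoss/dw2
        g1 += 2 * err * w2 * x / n   # dLoss/dw1
    return loss, g1, g2


def data_parallel_step(w1, w2, xs, ys, n_workers=4, lr=0.01):
    """One step: split the batch, compute local gradients, average, update."""
    per = len(xs) // n_workers       # assumes the batch divides evenly
    local = []
    for k in range(n_workers):
        part = slice(k * per, (k + 1) * per)
        _, g1, g2 = forward_backward(w1, w2, xs[part], ys[part])
        local.append((g1, g2))       # differs per worker: different data
    # "Synchronize": average the local gradients so every worker sees the same values.
    g1 = sum(g[0] for g in local) / n_workers
    g2 = sum(g[1] for g in local) / n_workers
    # Every worker applies the same SGD update, so the weights stay identical.
    return w1 - lr * g1, w2 - lr * g2, (g1, g2)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 16.3]
w1, w2 = 0.5, 1.5

new_w1, new_w2, (g1, g2) = data_parallel_step(w1, w2, xs, ys)
_, full_g1, full_g2 = forward_backward(w1, w2, xs, ys)
```

With equally sized parts, the averaged gradients `g1` and `g2` match the full-batch gradients `full_g1` and `full_g2` up to floating-point rounding, which is why every simulated GPU ends up with the same weights after the update.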

Am I right? Would FSDP work like that in the situation I described?

This tutorial and these 10 videos might give you the details about its implementation.

If you still have specific FSDP questions after the resources linked, I can try to help.