Hi
I’m trying to understand FSDP but can’t find any good resources. I’ve broken it down into steps. Can someone tell me if I’m right?
Let’s say that I have 4 GPUs and a model that can’t fit on one GPU, so it gets split into 2 shards.
- load weights from shard 1 into each GPU
- split batch of data into 4 parts
- do forward pass with part 1 on GPU 1, part 2 on GPU 2, etc.
- save intermediate outputs for backprop locally. They are different on each GPU because each GPU sees different data
- offload weights from shard 1
- load weights from shard 2 into each gpu
- do forward pass with outputs from step 4 and save intermediate outputs. Again, they are different on each GPU because of different inputs. Outputs from step 4 are still saved locally
- calculate loss
- load optimizer state for the weights from shard 2
- calculate gradients for weights from shard 2. They are different on each GPU
- synchronize gradients so they are the same on all GPUs
- update weights on each GPU locally. After the update, each GPU has the same weights
- offload optimizer state and weights from shard 2
- load weights from shard 1 and the optimizer state for them
- calculate gradients for weights from shard 1. They are different on each GPU
- synchronize gradients so they are the same on all GPUs
- update weights on each GPU locally. After the update, each GPU has the same weights
- offload optimizer state
- load another batch of data and split into 4 parts
- repeat steps 3-19 until the stopping criterion is met
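
To make the steps above concrete, here is a toy sketch in plain Python of what I think happens. It is only a simulation of my mental model, not real FSDP code: there are no actual GPUs, no loading or offloading, and no optimizer state (just plain SGD). The model is a made-up two-weight chain `y = w2 * (w1 * x)`, where `w1` stands in for shard 1 and `w2` for shard 2, and the helper names (`split_batch`, `local_grads`, `all_reduce_mean`, `train_step`) are my own, not part of any PyTorch API.

```python
# Toy simulation of the steps listed above: plain Python, no real GPUs.
# Hypothetical model: y = w2 * (w1 * x); w1 plays "shard 1", w2 plays "shard 2".

def split_batch(batch, n_parts):
    """Split a list of (x, target) pairs into n_parts equal-sized chunks."""
    size = len(batch) // n_parts
    return [batch[i * size:(i + 1) * size] for i in range(n_parts)]

def local_grads(w1, w2, part):
    """Forward and backward pass on one simulated GPU.

    Returns (mean loss, dL/dw1, dL/dw2) for this GPU's part of the batch.
    """
    loss = g1 = g2 = 0.0
    for x, t in part:
        h = w1 * x          # forward through shard 1 (activation kept for backprop)
        y = w2 * h          # forward through shard 2
        err = y - t
        loss += 0.5 * err ** 2
        g2 += err * h       # gradient for shard 2
        g1 += err * w2 * x  # gradient for shard 1
    n = len(part)
    return loss / n, g1 / n, g2 / n

def all_reduce_mean(values):
    """Stand-in for gradient synchronization: every GPU ends up with the mean."""
    return sum(values) / len(values)

def train_step(w1, w2, batch, n_gpus=4, lr=0.01):
    parts = split_batch(batch, n_gpus)
    results = [local_grads(w1, w2, p) for p in parts]  # one entry per simulated GPU
    loss = all_reduce_mean([r[0] for r in results])
    g1 = all_reduce_mean([r[1] for r in results])
    g2 = all_reduce_mean([r[2] for r in results])
    # Every GPU applies the same update, so the weights stay identical everywhere.
    return w1 - lr * g1, w2 - lr * g2, loss

batch = [(float(x), 2.0 * x) for x in range(1, 9)]  # 8 samples, target = 2 * x
w1, w2 = 0.5, 0.5
losses = []
for _ in range(50):
    w1, w2, loss = train_step(w1, w2, batch)
    losses.append(loss)
```

In this sketch each simulated GPU gets a different part of the batch, so the local gradients differ, but after averaging they match the gradient of the full batch (the parts are equal-sized), which is why every GPU ends up with the same weights after the update.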
Am I right? Would FSDP work like that in the situation I described?