Purpose and communication pattern of setting reshard_after_forward to an int in FSDP2

Hello everyone, I have several questions about setting reshard_after_forward to an integer (e.g. 8) when using PyTorch FSDP2:

  • What is the purpose of specifying the post-forward shard size as an integer, compared to the Boolean modes (True/False)? Is this reasonable in practice, and why not simply train with that shard size from the beginning?

  • Communication flow for the backward pass and the next forward pass, assuming the initial shard size (world size) is 32 and reshard_after_forward=8:

    • After the forward pass, each rank holds 1/8 of the model's parameters (resharded within an 8-GPU group).

    • During the backward pass, are both the all-gather and the reduce-scatter executed within those 8-GPU groups? Or is only the all-gather done within 8 GPUs, while the reduce-scatter (and any all-reduce) is done across all 32 GPUs?

    • For the next forward pass, is the original 32-way shard layout required again (i.e. are the parameters resharded back to 1/32 per rank)?
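For reference, this is roughly how I am configuring it. This is a minimal sketch, not my full training script: the model and layer sizes are placeholders, and I am assuming the FSDP2 fully_shard API (exposed under torch.distributed.fsdp in recent PyTorch releases), where reshard_after_forward accepts an int giving the world size to reshard to after forward. It would be launched with torchrun across 32 GPUs.

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2 API; path may differ by PyTorch version

# Placeholder model standing in for the real one.
model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(4)])

for layer in model:
    # Parameters are sharded across all 32 ranks while the layer is idle.
    # With reshard_after_forward=8, after the layer's forward pass each rank
    # keeps 1/8 of the parameters, instead of freeing them back to the full
    # 1/32 shard (True) or keeping them unsharded (False).
    fully_shard(layer, reshard_after_forward=8)
fully_shard(model, reshard_after_forward=8)
```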