Does FSDP have support for this? I only want to shard certain parts of the model. I know FullyShardedDataParallel has an ignored_modules arg, but I don't understand how that is supposed to work when one has nn.Linear layers that should sometimes be sharded and sometimes not.
I was also wondering whether one could manually call torch.distributed.fsdp.wrap on individual parts of the model, but some initial tests suggest that approach does not work (it seems the root module needs to be of type FullyShardedDataParallel)?
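If I understand the wrap utility correctly (a sketch of my reading of it, not authoritative), wrap() is only effective inside an enable_wrap context, which supplies the wrapper class; outside that context it returns the module unchanged, which would explain why bare calls appear to do nothing:

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap

def build_layer(shard: bool) -> nn.Module:
    if shard:
        # Inside the context, wrap() replaces the module with FSDP(module)
        # (this path needs an initialized process group to actually run).
        with enable_wrap(wrapper_cls=FSDP):
            return wrap(nn.Linear(8, 8))
    # Outside any enable_wrap context, wrap() is a no-op and the plain
    # nn.Linear comes back untouched.
    return wrap(nn.Linear(8, 8))
```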
Related to that is the question of what happens to the outer (root) FSDP unit during training. This is mentioned in the FSDP advanced tutorial on the transformer wrapping policy. Does the root unit get unloaded (resharded) while other (child) FSDP units are running, or is it always kept in the unsharded state?