When using FSDP, how do I unshard a single module's parameters outside of forward/backward?


When wrapping a model like:

fsdp_model = FullyShardedDataParallel(model)

using summon_full_params(fsdp_model) will unshard the parameters of every wrapped module, which materializes the full model on each rank and causes OOM for a large model. I know that I can instead wrap modules individually, like this:

import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap, wrap

class Model(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.layer1 = wrap(nn.Linear(8, 4))
        self.layer2 = nn.Linear(4, 16)
        self.layer3 = wrap(nn.Linear(16, 4))

wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True))
with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
    fsdp_model = wrap(Model())

But still, the wrapped model is the root module, and summon_full_params() expects the root module in order to unshard the parameters, so this alone doesn't help.
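In other words, the only pattern I know of is the sketch below (assuming a process group is already initialized and summon_full_params is used as a context manager), and it gathers every parameter at once:

# Summoning on the root instance unshards ALL parameters of the model
# on every rank, which is exactly what I want to avoid for a large model.
with FullyShardedDataParallel.summon_full_params(fsdp_model):
    total = sum(p.numel() for p in fsdp_model.parameters())
    print(f"unsharded parameter count on this rank: {total}")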

My question is: is it correct, safe, and optimal to do this instead:

model = Model()

with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
    fsdp_layer1 = wrap(model.layer1)

with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
    fsdp_layer3 = wrap(model.layer3)

and then apply summon_full_params() on each of those modules individually?
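Concretely, I mean a minimal sketch like this (reusing fsdp_layer1 and fsdp_layer3 from above; it assumes the process group is already initialized):

# Only layer1 is unsharded inside this context; layer3 and the rest of
# the model stay sharded, so peak memory is bounded by one layer at a time.
with FullyShardedDataParallel.summon_full_params(fsdp_layer1):
    for name, param in fsdp_layer1.named_parameters():
        print(name, tuple(param.shape))  # full, unsharded shapes

with FullyShardedDataParallel.summon_full_params(fsdp_layer3):
    for name, param in fsdp_layer3.named_parameters():
        print(name, tuple(param.shape))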