While using FSDP, how do I unshard a single module's parameters outside forward/backward?

Hi,

When wrapping a model like this:

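```python
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import default_auto_wrap_policy

fsdp_model = FullyShardedDataParallel(
    model(),
    fsdp_auto_wrap_policy=default_auto_wrap_policy,
    cpu_offload=CPUOffload(offload_params=True),
)
```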

Calling summon_full_params() on fsdp_model unshards the parameters of every wrapped module, so every rank ends up holding the full model, which causes an OOM for a large model. I know that I can wrap modules individually like this:

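```python
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap, wrap

class model(nn.Module):
    def __init__(self):
        super().__init__()
        # wrap() only wraps when called inside an enable_wrap() context
        self.layer1 = wrap(nn.Linear(8, 4))
        self.layer2 = nn.Linear(4, 16)
        self.layer3 = wrap(nn.Linear(16, 4))

wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True))
with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
    fsdp_model = wrap(model())
```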
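The nesting that the snippet above produces can be checked without a process group by handing `enable_wrap` a stand-in wrapper class. This is a minimal sketch, assuming `enable_wrap` and `wrap` from `torch.distributed.fsdp.wrap` (as used above); `RecordingWrapper`, `Model`, and the `tag` keyword are hypothetical names used only so the snippet runs on a single CPU process. It shows that `layer1` and `layer3` become separate inner units inside the root, while `layer2` stays a plain module owned by the root.

```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import enable_wrap, wrap


class RecordingWrapper(nn.Module):
    """Hypothetical stand-in for FullyShardedDataParallel: it only stores
    the wrapped module and the keyword arguments it was built with."""

    def __init__(self, module, **kwargs):
        super().__init__()
        self.module = module
        self.wrap_kwargs = kwargs

    def forward(self, x):
        return self.module(x)


class Model(nn.Module):
    """Hypothetical mirror of the model class from the post."""

    def __init__(self):
        super().__init__()
        self.layer1 = wrap(nn.Linear(8, 4))   # wrapped only inside enable_wrap
        self.layer2 = nn.Linear(4, 16)        # never wrapped: belongs to the root
        self.layer3 = wrap(nn.Linear(16, 4))  # wrapped only inside enable_wrap

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))


# Model() is constructed inside the context, so the inner wrap() calls fire too;
# every wrap() call receives the kwargs given to enable_wrap().
with enable_wrap(wrapper_cls=RecordingWrapper, tag="demo"):
    wrapped = wrap(Model())

# Outside any enable_wrap() context, wrap() returns the module unchanged.
unwrapped = Model()
```

With `FullyShardedDataParallel` as `wrapper_cls`, each inner unit is its own FSDP instance. The PyTorch documentation for the static `FullyShardedDataParallel.summon_full_params(module, recurse=True, ...)` in recent releases notes that it can be used on inner FSDP instances, and it also exposes `rank0_only` and `offload_to_cpu` options intended to reduce memory use; whether these are available depends on the installed version.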

But the model is still the root module, and summon_full_params() expects to be called on the root module in order to unshard the parameters.

My question is: is it correct, safe, and optimal to do this:

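```python
base_model = model()  # built outside enable_wrap, so the wrap() calls in __init__ are no-ops

with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
    fsdp_layer1 = wrap(base_model.layer1)

with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
    fsdp_layer3 = wrap(base_model.layer3)
```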

and then call summon_full_params() on each of those individually wrapped modules?
Thanks!