[FSDP] Is there any way to unwrap FSDP when I want to save the module's state?

Hello, I’m trying to train my ASR model with FullyShardedDataParallel.
I saved the FSDP-wrapped model directly, like below.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

torch.distributed.barrier()
model_state = model.state_dict()  # model = FSDP(mymodel())
if rank == 0:
    torch.save({'model': model_state}, model_path)

But when I load that state into my unwrapped model, it fails because the checkpoint contains flattened parameters.

model = mymodel()
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])

Unexpected key(s) in state_dict: "_fsdp_wrapped_module.flat_param"

And when I wrap the model with FSDP and then try to load the state, it seems that only rank 0’s shard was saved.

Error(s) in loading state_dict for FullyShardedDataParallel:
	size mismatch for _fsdp_wrapped_module.flat_param: copying a param with shape torch.Size([2430345]) from checkpoint, the shape in current model is torch.Size([14582066]).

Is there any way to save the unwrapped model state?

Thank you.

FSDP keeps the model parameters sharded (and flattened) outside of the forward/backward passes, which is why you’re seeing this.

In order to see both the unwrapped and unsharded model state, you can use the summon_full_params context manager. Unfortunately, this method is not available in PyTorch 1.11, so you’ll need to use a nightly build.

@Rohan_Varma, you have more visibility into FSDP checkpointing; could you help on this point?

Thank you for your reply.

I installed a nightly build and tried to test it like below.

if rank == 0:
    model.clip_grad_norm_(1)
    model_state = model.summon_full_params()

    torch.save(
        {
            'model': model_state,
        },
        model_path
    )

summon_full_params is a context manager and should be used in a with statement.
Beyond that, you have to pass the unwrapped module (model.module) to torch.save to skip unserializable objects like the ProcessGroup.
Also note that summon_full_params performs collective communication, so every rank needs to enter the context; only the torch.save itself should be restricted to rank 0.

# summon_full_params runs collectives, so every rank must enter the context;
# only the actual torch.save is guarded by the rank check
with model.summon_full_params():
    if rank == 0:
        torch.save(
            {
                'model': model.module,
            },
            model_path
        )
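
For loading that checkpoint back, something like this should work (an untested sketch, not from this thread): since torch.save was given the module object itself, torch.load returns a plain nn.Module, which you can use directly or copy into a freshly constructed model.

# loading the module object saved above
checkpoint = torch.load(model_path, map_location='cpu')
loaded = checkpoint['model']  # the plain, unwrapped module

# or, to keep your usual construction path, copy its parameters over
model = mymodel()
model.load_state_dict(loaded.state_dict())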

Since you’re on a nightly build, you can also use the new state_dict APIs that were designed for saving/loading checkpoints: FullyShardedDataParallel — PyTorch 1.12 documentation
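
For example, a rough sketch of saving a full, unflattened state dict with those APIs (assuming a 1.12-era build where FullStateDictConfig and StateDictType are exported from torch.distributed.fsdp) could look like this:

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# gather the full (unflattened) parameters, offloaded to CPU and materialized only on rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    model_state = model.state_dict()  # call on every rank; keys match the unwrapped module

if rank == 0:
    torch.save({'model': model_state}, model_path)

The resulting checkpoint can then be loaded straight into the unwrapped model with mymodel().load_state_dict(checkpoint['model']).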
