Hello, I’m trying to train my ASR model with FullyShardedDataParallel (FSDP).
I directly saved the FSDP-wrapped model’s state like below:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model_state = model.state_dict()  # model = FSDP(mymodel())
```
But when I load the state into my unwrapped model, it returns flattened parameters:

```python
model = mymodel()
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])
```

```
Unexpected key(s) in state_dict: "_fsdp_wrapped_module.flat_param"
```
And when I wrap the model with FSDP and then try loading the state, it seems only rank 0’s params have been saved:

```
Error(s) in loading state_dict for FullyShardedDataParallel:
size mismatch for _fsdp_wrapped_module.flat_param: copying a param with shape torch.Size() from checkpoint, the shape in current model is torch.Size().
```
Is there any way to save the unwrapped model state?
FSDP keeps the model sharded outside of forward/backward, which is why you’re seeing that.
To see both the unwrapped and unsharded model state, you can use the
summon_full_params context manager. Unfortunately this method is not available in PyTorch 1.11, so you’ll need to use a nightly build.
@Rohan_Varma you have more visibility on FSDP checkpointing, could you help on that point?
Thank you for your reply.
I installed a nightly build and tried to test like below:
```python
if rank == 0:
    model_state = model.summon_full_params()
    torch.save({'model': model_state}, model_path)
```
summon_full_params is a context manager and should be used in a `with` statement.
Beyond that, you have to pass the unwrapped module to torch.save to skip unserializable objects like ProcessGroup:

```python
if rank == 0:
    torch.save({'model': model.module}, model_path)
```
As you’re using a nightly build, you can now use the new
state_dict APIs that were designed for saving/loading: FullyShardedDataParallel — PyTorch 1.12 documentation