Hello, I’m trying to train my ASR model with FullyShardedDataParallel (FSDP).
I directly saved the FSDP-wrapped model’s state like below:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model_state = model.state_dict()  # model = FSDP(mymodel())
```
But when I load the state into my unwrapped model, it returns flattened parameters:

```python
model = mymodel()
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])
```

```
Unexpected key(s) in state_dict: "_fsdp_wrapped_module.flat_param"
```
And when I wrap the model with FSDP and then try loading the state, it seems only rank 0’s params have been saved:

```
Error(s) in loading state_dict for FullyShardedDataParallel:
size mismatch for _fsdp_wrapped_module.flat_param: copying a param with shape torch.Size() from checkpoint, the shape in current model is torch.Size().
```
Is there any way to save the unwrapped model state?
FSDP keeps the model sharded outside of forward/backward, which is why you’re seeing that.
To see both the unwrapped and unsharded model state, you can use the
summon_full_params context manager. Unfortunately this method is not available in PyTorch 1.11, so you’ll need to use a nightly build.
@Rohan_Varma you have more visibility on FSDP checkpointing, could you help on that point?
Thank you for your reply.
I installed a nightly build and tried to test like below:
```python
if rank == 0:
    model_state = model.summon_full_params()
    torch.save({'model': model_state}, model_path)
```
summon_full_params is a context manager and should be used in a `with` statement.
Beyond that, you have to pass the unwrapped module to torch.save to skip unserializable objects like ProcessGroup:

```python
if rank == 0:
    torch.save({'model': model.module}, model_path)
```
As you’re using a nightly build, you can now use the new
state_dict APIs that were designed for saving/loading: FullyShardedDataParallel — PyTorch 1.12 documentation