I have been using FSDP on 1.11 stable, checkpointing on each device with the wrapped model: model.state_dict(). When loading the checkpoint I would initialize the model, wrap with FSDP then do load_state_dict on the wrapped model on each device, and training would resume successfully.
I’m looking to do some evals and can probably fit the entire model on a single device, and therefore need to unshard things.
Are these checkpoints usable with any of the APIs available in 1.12? load_state_dict, load_local_state_dict, or load_sharded_state_dict?
If you save your model that way, you won’t be able to change the cluster topology later.
There are essentially two recommended ways to checkpoint an FSDP model:
Save the whole model from rank 0:
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

model = ...

# Save
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    state_dict = model.state_dict()
if dist.get_rank() == 0:
    torch.save(state_dict, "checkpoint.pt")

# Load
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    model.load_state_dict(torch.load("checkpoint.pt"))
Save in a distributed fashion using the experimental distributed checkpointing API:
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
import torch.distributed._shard.checkpoint as dist_cp

model = ...

# Save
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    checkpoint = model.state_dict()
    dist_cp.save_state_dict(
        state_dict=checkpoint,
        storage_writer=dist_cp.FileSystemWriter("checkpoint"),
    )

# Load
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    checkpoint = model.state_dict()
    dist_cp.load_state_dict(
        state_dict=checkpoint,
        storage_reader=dist_cp.FileSystemReader("checkpoint"),
    )
    model.load_state_dict(checkpoint)
While you can use StateDictType.LOCAL_STATE_DICT and checkpoint in a similar way to how you’re doing it today, it won’t handle the scenario of changing your cluster topology (i.e., going from N ranks to a single rank).
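To see why rank-local checkpoints break when the topology changes, here is a minimal plain-Python illustration (no FSDP involved, just a sketch of the idea): each rank saves only its own slice of a flat parameter, so reloading requires the same number of slices with the same boundaries.

```python
def shard(params, world_size):
    """Split a flat parameter list into one contiguous shard per rank."""
    per_rank = (len(params) + world_size - 1) // world_size  # ceil division
    return [params[i * per_rank:(i + 1) * per_rank] for i in range(world_size)]

def reassemble(shards):
    """Concatenate rank-local shards back into the flat parameter list."""
    return [p for s in shards for p in s]

params = list(range(10))

# Saved from a 4-rank run: four files, one shard per rank.
shards_4 = shard(params, 4)

# Reloading on the same 4 ranks works fine.
assert reassemble(shards_4) == params

# But a 3-rank job would shard the same parameters with different slice
# boundaries, so the saved per-rank files no longer line up.
shards_3 = shard(params, 3)
assert [len(s) for s in shards_3] != [len(s) for s in shards_4]
```

This is why FULL_STATE_DICT or SHARDED_STATE_DICT (which stores re-shardable metadata) is needed when the rank count may change.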
If you have further questions and a Slack account, feel free to ping me about checkpointing - I’d love to hear about your experience with it.
After running this function on a model you can call state_dict() and strip the “_fsdp_wrapped_module._fpw_module.” prefix from the keys to produce something that looks like an original model checkpoint. However, my testing indicates the parameters are different, so I’m not sure it can be relied upon. I will do a little more testing.
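For reference, the key renaming described above can be done with a dict comprehension. A minimal sketch, assuming the prefix is exactly the string quoted above (verify against your own state_dict keys first; the example keys below are hypothetical):

```python
# Strip the FSDP wrapper prefix from checkpoint keys so they match the
# unwrapped model's parameter names.
PREFIX = "_fsdp_wrapped_module._fpw_module."

def strip_fsdp_prefix(state_dict):
    # .replace() handles the prefix repeating at nested wrapped submodules.
    return {k.replace(PREFIX, ""): v for k, v in state_dict.items()}

# Hypothetical example keys:
wrapped = {
    "_fsdp_wrapped_module._fpw_module.linear.weight": "w",
    "_fsdp_wrapped_module._fpw_module.linear.bias": "b",
}
print(strip_fsdp_prefix(wrapped))
# {'linear.weight': 'w', 'linear.bias': 'b'}
```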
For my smaller models I can retrain fairly cheaply and am using the FULL_STATE_DICT which works very nicely!