Migrating From 1.11 to 1.12 FSDP

I have been using FSDP on 1.11 stable, checkpointing on each device by calling state_dict() on the wrapped model. When loading a checkpoint I would initialize the model, wrap it with FSDP, then call load_state_dict on the wrapped model on each device, and training would resume successfully.
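Roughly, the flow I have today looks like this - a minimal sketch, where build_model() and the checkpoint paths are just placeholders for my own code:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

rank = dist.get_rank()
model = FSDP(build_model().cuda())  # build_model() is a stand-in for my model constructor

# Save: each rank writes the wrapped model's state_dict to its own file
torch.save(model.state_dict(), f"ckpt_rank{rank}.pt")

# ...later, to resume: rebuild, wrap with FSDP again, load the matching per-rank file
model = FSDP(build_model().cuda())
model.load_state_dict(torch.load(f"ckpt_rank{rank}.pt"))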

I’m looking to do some evals and can probably fit the entire model on a single device, so I need to unshard the checkpoint.

Are these checkpoints usable with any of the APIs available in 1.12: load_state_dict, load_local_state_dict, or load_sharded_state_dict?

Thank You

cc @Yanli_Zhao @rvarm1 @agu for FSDP-related questions

If you save your model that way, you won’t be able to change the cluster topology later.

There are essentially two recommended ways to checkpoint an FSDP model:

Save the whole model from rank 0:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

model = ...
# Save: gather the full, unsharded state dict, then write it out from rank 0 only
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state_dict, "checkpoint.pt")

# Load: every rank loads the full checkpoint into the wrapped model
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
    model.load_state_dict(torch.load("checkpoint.pt"))

Save in a distributed fashion using the experimental distributed checkpointing API:

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
import torch.distributed._shard.checkpoint as dist_cp

model = ...
# Save: each rank writes its shards to the "checkpoint" directory
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    checkpoint = model.state_dict()
    dist_cp.save_state_dict(
        state_dict=checkpoint,
        storage_writer=dist_cp.FileSystemWriter("checkpoint"),
    )

# Load: read the shards back into the sharded state dict, then load it into the model
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    checkpoint = model.state_dict()
    dist_cp.load_state_dict(
        state_dict=checkpoint,
        storage_reader=dist_cp.FileSystemReader("checkpoint"),
    )
    model.load_state_dict(checkpoint)

While you can use StateDictType.LOCAL_STATE_DICT and checkpoint in a similar way to how you’re doing it today, it won’t handle the scenario of changing your cluster topology (i.e., going from N ranks to a single rank).
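For reference, a minimal sketch of what the LOCAL_STATE_DICT flow would look like (the per-rank paths are placeholders) - note that every shard is tied to the world size and wrapping it was saved with:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

model = ...
rank = dist.get_rank()

# Save: each rank writes its own local (flattened, sharded) state dict
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    torch.save(model.state_dict(), f"checkpoint_rank{rank}.pt")

# Load: requires the same number of ranks and the same wrapping as at save time
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    model.load_state_dict(torch.load(f"checkpoint_rank{rank}.pt"))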

If you have further questions and have a Slack account, feel free to ping me about checkpointing - I’d love to hear about your experience with it.


Ok thank you.

I did find something in the 1.11 tests that might have recovered the params: pytorch/common_fsdp.py at bc2c6edaf163b1a1330e37a6e34caf8c553e4755 · pytorch/pytorch · GitHub

After running this function on a model you can call state_dict() and strip the “_fsdp_wrapped_module._fpw_module.” prefix from the keys to produce something that looks like an original model checkpoint. However, my testing indicates the parameters are different, so I’m not sure it can be relied upon. I will do a little more testing.
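The key cleanup I’m doing is roughly this - a sketch only, where the prefix string is what I observed in my own checkpoints and unwrapped_model is a fresh, non-FSDP instance of the model:

# Strip the FSDP wrapper prefix so the keys match the original module's names
PREFIX = "_fsdp_wrapped_module._fpw_module."

def strip_fsdp_prefixes(state_dict):
    return {key.replace(PREFIX, ""): value for key, value in state_dict.items()}

clean_sd = strip_fsdp_prefixes(model.state_dict())
unwrapped_model.load_state_dict(clean_sd)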

For my smaller models I can retrain fairly cheaply, and I’m using FULL_STATE_DICT, which works very nicely!