Saving FSDP model via full_optim_state_dict

Hi,

I'm trying to find out why an FSDP model won't save its optimizer state when the call is made from all processes in parallel:

import torch
from torch.distributed import fsdp
...
model = ...  # torch.nn.Module
model = fsdp.FullyShardedDataParallel(model)
optimizer = torch.optim.AdamW(model.parameters())  # the optimizer takes the parameters, not the module
...

state_dict_config = fsdp.FullStateDictConfig(
            offload_to_cpu=True, rank0_only=True)
state_dict_type = fsdp.StateDictType.FULL_STATE_DICT
with fsdp.FullyShardedDataParallel.state_dict_type(
            model,
            state_dict_type=state_dict_type,
            state_dict_config=state_dict_config):
    model_state = model.state_dict()

optimizer_state = fsdp.FullyShardedDataParallel.full_optim_state_dict(
                model=model, optim=optimizer, rank0_only=True)

It fails with:

File ".venv/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1331, in full_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File ".venv/lib/python3.8/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1154, in _optim_state_dict_impl
    return _optim_state_dict(
  File ".venv/lib/python3.8/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1455, in _optim_state_dict
    _gather_orig_param_state(
  File ".venv/lib/python3.8/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1697, in _gather_orig_param_state
    value = value[: flat_param._numels[param_idx]].reshape(
RuntimeError: shape '[50304, 768]' is invalid for input of size 8

So the model state is gathered, but the optimizer state is not. This is somewhat strange, since the code follows the examples in the PyTorch tests and the one from Mosaic.
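The last two frames of the traceback are ordinary tensor operations, so the failure can be reproduced in isolation. This is a minimal sketch that assumes, purely for illustration, that the optimizer-state tensor reaching `_gather_orig_param_state` for the 50304 x 768 parameter held only 8 elements (the size from the error message). The name `value` mirrors the traceback; everything else is hypothetical. Slicing with `[: numel]` silently clamps to the tensor's length, so the mismatch only surfaces at `reshape`.

```python
import torch

# Hypothetical stand-in for the per-parameter optimizer state that arrived:
# only 8 elements, as reported in the error message.
value = torch.zeros(8)

# Shape of the original (unflattened) parameter from the traceback.
orig_shape = (50304, 768)
orig_numel = orig_shape[0] * orig_shape[1]  # 38,633,472 elements

# Slicing past the end of a tensor does not raise; it returns what exists.
sliced = value[:orig_numel]

# The reshape is where the size mismatch finally surfaces.
try:
    sliced.reshape(orig_shape)
    message = None
except RuntimeError as err:
    message = str(err)

print(message)
```

In other words, the state tensor reaching that point for this parameter had far fewer elements than the parameter itself.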

Any idea what could cause this? The torch version is 2.0.0. Thank you very much in advance!


It might be best to file a GitHub issue with a way to reproduce this, so we can help out.
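As a starting point for such a repro, the save step from the question can be wrapped in a function that every rank calls, since both the model gather and `full_optim_state_dict` are collectives. This is a hedged sketch, not an official recipe: `save_full_checkpoint` and the checkpoint dict layout are hypothetical, it assumes a process group is already initialized (for example via `torchrun`), and it keeps the question's `full_optim_state_dict` call with the optimizer built from `model.parameters()`.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def save_full_checkpoint(model: FSDP, optimizer: torch.optim.Optimizer, path: str) -> None:
    """Hypothetical helper: gather full model and optimizer state, write from rank 0.

    Must be called on every rank, because both gathers are collectives.
    """
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        model_state = model.state_dict()

    # Also a collective; with rank0_only=True only rank 0 receives the full dict.
    optim_state = FSDP.full_optim_state_dict(model, optimizer, rank0_only=True)

    if dist.get_rank() == 0:
        torch.save({"model": model_state, "optimizer": optim_state}, path)
    dist.barrier()  # keep ranks in step before training continues
```

Calling this from only one rank would leave the other ranks out of the collectives, so the repro should invoke it unconditionally on all of them.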
